Tensorflow.js:expandDims() + cast() 问题
Tensorflow.js: expandDims() + cast() issue
我需要为我的模型提供一个 float32
张量。我需要使用 expandDims(tensor, axis=0)
将其形状从 [240, 320, 3]
更改为 [1, 240, 320, 3]
。但是,expandDims()
操作似乎将我的张量转换为 int32
.
当我对这个张量执行 cast(tensor, "float32")
时,似乎 cast
操作将我的张量压缩回 [240, 320, 3]
。
image_v = (tf.cast(image_raw, "float32") / 255.0) - 0.5;
image_v = tf.expandDims(image_raw, 0);
console.log(image_v.shape) // shape: [1, 240, 320, 3]
console.log(image_v.dtype) // dtype: "int32"
image_v = tf.cast(image_raw, "float32")
console.log(image_v.shape) // shape: [240, 320, 3]
console.log(image_v.dtype) // dtype: "float32"
我正在寻找一种方法来扩展 tensorflow.js
中 float32
张量的 dims,并使张量的 dtype
保持 float32
。任何帮助将不胜感激!
tfjs不能用JS对tensor进行运算,得用tf.div()
和tf.sub()
.
image_v = (tf.cast(image_raw, "float32") / 255.0) - 0.5;
image_v
现在是 NaN
,因为 ({}/255)-0.5 === NaN
image_v = tf.expandDims(image_raw, 0);
现在您展开未修改的原始图像。
image_v = tf.cast(image_raw, "float32")
您重复使用了原始 image_raw
,因为在 tf 操作中不修改张量。他们总是创造一个新的。
并且我建议不要重复使用变量或在 tf.tidy()
之外工作,因为您很容易忘记 .dispose()
,从而造成内存泄漏。
const image_raw = tf.zeros([240, 320, 3]);
const modified = tf.tidy(() => {
const image_v_casted = tf.cast(image_raw, "float32").div(255).sub(0.5);
const image_v_expanded = tf.expandDims(image_v_casted, 0);
return image_v_expanded;
});
console.log('shape', modified.shape);
console.log('dtype', modified.dtype);
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script>
你似乎认为,像 tf.cast
这样的操作对原始张量有效。但事实并非如此。引用文档:
A tf.Tensor object represents an immutable, multidimensional array of numbers that has a shape and a data type.
这意味着,每当您调用 tf.cast
或 tf.expandDims
之类的函数时,都会创建一个新的张量。因此,您的 tf.cast(image_raw, "float32")
调用不会更改原始张量,而是创建一个新张量。
解决方案
要转换张量,您应该使用 image_v
而不是 image_raw
,因为后者的张量从未更改过。
image_v = tf.cast(image_raw, "float32");
此外,您的脚本开头因其他原因无法正常工作(尽管这与您面临的问题无关)。正常的 JavaScript 操作不适用于张量。请查看 tf.div
and tf.sub
。
所以,总而言之,您的代码应该如下所示:
image_v = tf.cast(image_raw, "float32").div(255).sub(0.5);
image_v = tf.expandDims(image_v, 0);
console.log(image_v.shape) // should now be: [1, 240, 320, 3]
console.log(image_v.dtype) // and this should be: dtype: "int32"
就像已经建议的其他答案一样,您还应该查看 tf.tidy
以防止内存泄漏。
我需要为我的模型提供一个 float32
张量。我需要使用 expandDims(tensor, axis=0)
将其形状从 [240, 320, 3]
更改为 [1, 240, 320, 3]
。但是,expandDims()
操作似乎将我的张量转换为 int32
.
当我对这个张量执行 cast(tensor, "float32")
时,似乎 cast
操作将我的张量压缩回 [240, 320, 3]
。
image_v = (tf.cast(image_raw, "float32") / 255.0) - 0.5;
image_v = tf.expandDims(image_raw, 0);
console.log(image_v.shape) // shape: [1, 240, 320, 3]
console.log(image_v.dtype) // dtype: "int32"
image_v = tf.cast(image_raw, "float32")
console.log(image_v.shape) // shape: [240, 320, 3]
console.log(image_v.dtype) // dtype: "float32"
我正在寻找一种方法来扩展 tensorflow.js
中 float32
张量的 dims,并使张量的 dtype
保持 float32
。任何帮助将不胜感激!
tfjs不能用JS对tensor进行运算,得用tf.div()
和tf.sub()
.
image_v = (tf.cast(image_raw, "float32") / 255.0) - 0.5;
image_v
现在是 NaN
,因为 ({}/255)-0.5 === NaN
image_v = tf.expandDims(image_raw, 0);
现在您展开未修改的原始图像。
image_v = tf.cast(image_raw, "float32")
您重复使用了原始 image_raw
,因为在 tf 操作中不修改张量。他们总是创造一个新的。
并且我建议不要重复使用变量或在 tf.tidy()
之外工作,因为您很容易忘记 .dispose()
,从而造成内存泄漏。
const image_raw = tf.zeros([240, 320, 3]);
const modified = tf.tidy(() => {
const image_v_casted = tf.cast(image_raw, "float32").div(255).sub(0.5);
const image_v_expanded = tf.expandDims(image_v_casted, 0);
return image_v_expanded;
});
console.log('shape', modified.shape);
console.log('dtype', modified.dtype);
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.0/dist/tf.min.js"></script>
你似乎认为,像 tf.cast
这样的操作对原始张量有效。但事实并非如此。引用文档:
A tf.Tensor object represents an immutable, multidimensional array of numbers that has a shape and a data type.
这意味着,每当您调用 tf.cast
或 tf.expandDims
之类的函数时,都会创建一个新的张量。因此,您的 tf.cast(image_raw, "float32")
调用不会更改原始张量,而是创建一个新张量。
解决方案
要转换张量,您应该使用 image_v
而不是 image_raw
,因为后者的张量从未更改过。
image_v = tf.cast(image_raw, "float32");
此外,您的脚本开头因其他原因无法正常工作(尽管这与您面临的问题无关)。正常的 JavaScript 操作不适用于张量。请查看 tf.div
and tf.sub
。
所以,总而言之,您的代码应该如下所示:
image_v = tf.cast(image_raw, "float32").div(255).sub(0.5);
image_v = tf.expandDims(image_v, 0);
console.log(image_v.shape) // should now be: [1, 240, 320, 3]
console.log(image_v.dtype) // and this should be: dtype: "int32"
就像已经建议的其他答案一样,您还应该查看 tf.tidy
以防止内存泄漏。