Trim 变长张量到最大长度
Trim variable-length tensor to a maximum length
具有一个固定维度和一个可变长度维度的二维张量:如何将可变长度维度限制为最大长度?如果可变长度较短,则应保留最大值(而不是填充),但如果较长,则应截断末尾。
例如,假设所有张量的形状都是 (None, 4)
,我想将它们全部限制在 (3, 4)
的最大形状内。一个示例输入可以是:
tensor1 = tf.constant([
[1, 2, 0, 0],
[1, 3, 4, 0],
[0, 0, 0, 0],
[7, 7, 7, 7],
[7, 8, 9, 1],
], dtype=tf.int32)
...,应修剪为:
tensor1_trimmed = tf.constant([
[1, 2, 0, 0],
[1, 3, 4, 0],
[0, 0, 0, 0],
], dtype=tf.int32)
但是,小于最大值的任何值都应保持不变:
tensor2 = tf.constant([
[9, 9, 9, 9],
[9, 9, 9, 9],
], dtype=tf.int32)
...应该保持完全相同:
tensor2_trimmed = tf.constant([
[9, 9, 9, 9],
[9, 9, 9, 9],
], dtype=tf.int32)
是否有任何内置命令可以做到这一点?或者您将如何实现这一目标?
tf.strided_slice
支持 numpy 样式的切片,因此您可以在示例中使用 [:3,:]
>>> tensor1 = tf.constant([
... [1, 2, 0, 0],
... [1, 3, 4, 0],
... [0, 0, 0, 0],
... [7, 7, 7, 7],
... [7, 8, 9, 1],
... ], dtype=tf.int32)
>>> tensor1[:3,:]
<tf.Tensor: shape=(3, 4), dtype=int32, numpy=
array([[1, 2, 0, 0],
[1, 3, 4, 0],
[0, 0, 0, 0]], dtype=int32)>
>>> tensor2 = tf.constant([
... [9, 9, 9, 9],
... [9, 9, 9, 9],
... ], dtype=tf.int32)
>>> tensor2[:3,:]
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
array([[9, 9, 9, 9],
[9, 9, 9, 9]], dtype=int32)>
具有一个固定维度和一个可变长度维度的二维张量:如何将可变长度维度限制为最大长度?如果可变长度较短,则应保留最大值(而不是填充),但如果较长,则应截断末尾。
例如,假设所有张量的形状都是 (None, 4)
,我想将它们全部限制在 (3, 4)
的最大形状内。一个示例输入可以是:
tensor1 = tf.constant([
[1, 2, 0, 0],
[1, 3, 4, 0],
[0, 0, 0, 0],
[7, 7, 7, 7],
[7, 8, 9, 1],
], dtype=tf.int32)
...,应修剪为:
tensor1_trimmed = tf.constant([
[1, 2, 0, 0],
[1, 3, 4, 0],
[0, 0, 0, 0],
], dtype=tf.int32)
但是,小于最大值的任何值都应保持不变:
tensor2 = tf.constant([
[9, 9, 9, 9],
[9, 9, 9, 9],
], dtype=tf.int32)
...应该保持完全相同:
tensor2_trimmed = tf.constant([
[9, 9, 9, 9],
[9, 9, 9, 9],
], dtype=tf.int32)
是否有任何内置命令可以做到这一点?或者您将如何实现这一目标?
tf.strided_slice
支持 numpy 样式的切片,因此您可以在示例中使用 [:3,:]
>>> tensor1 = tf.constant([
... [1, 2, 0, 0],
... [1, 3, 4, 0],
... [0, 0, 0, 0],
... [7, 7, 7, 7],
... [7, 8, 9, 1],
... ], dtype=tf.int32)
>>> tensor1[:3,:]
<tf.Tensor: shape=(3, 4), dtype=int32, numpy=
array([[1, 2, 0, 0],
[1, 3, 4, 0],
[0, 0, 0, 0]], dtype=int32)>
>>> tensor2 = tf.constant([
... [9, 9, 9, 9],
... [9, 9, 9, 9],
... ], dtype=tf.int32)
>>> tensor2[:3,:]
<tf.Tensor: shape=(2, 4), dtype=int32, numpy=
array([[9, 9, 9, 9],
[9, 9, 9, 9]], dtype=int32)>