Generate smaller tensors for each sequence of values

Consider this tensor:

    a = tf.constant([0,1,2,3,5,6,7,8,9,10,19,20,21,22,23,24])

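I want to split it into 3 tensors (for this particular example), each containing one group of consecutive numbers. The expected output is: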

    output_tensor = [ [0,1,2,3], [5,6,7,8,9,10], [19,20,21,22,23,24] ]

Any idea how to do this? Is there a tf.math method that could help do this efficiently? I couldn't find anything.

For the example provided, tf.split should work:

    a = tf.constant([0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 19, 20, 21, 22, 23, 24])
    print(tf.split(a, [4, 6, 6]))

Output:

[<tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(6,), dtype=int32, numpy=array([ 5,  6,  7,  8,  9, 10], dtype=int32)>, <tf.Tensor: shape=(6,), dtype=int32, numpy=array([19, 20, 21, 22, 23, 24], dtype=int32)>]

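The second argument specifies the size of each output tensor along the split axis (axis 0 by default), so here the first tensor has size 4, the second size 6, and the third size 6. Alternatively, you can pass a single int, as long as the size of the tensor along the split axis is evenly divisible by it. Here 3 will not work (16/3 = 5.333...), but 4 will: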

    a = tf.constant([0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 19, 20, 21, 22, 23, 24])
    print(tf.split(a, 4))

Output:

[<tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(4,), dtype=int32, numpy=array([5, 6, 7, 8], dtype=int32)>, <tf.Tensor: shape=(4,), dtype=int32, numpy=array([ 9, 10, 19, 20], dtype=int32)>, <tf.Tensor: shape=(4,), dtype=int32, numpy=array([21, 22, 23, 24], dtype=int32)>]
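To make the two forms of the second argument concrete, here is a small sketch that checks the divisibility rule for the integer form and confirms that each entry of the list form corresponds to a contiguous slice of the input. It assumes TensorFlow 2.x with eager execution (the default), since it calls .numpy(); the variable names are only for illustration.

```python
import tensorflow as tf

a = tf.constant([0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 19, 20, 21, 22, 23, 24])
n = a.shape[0]  # static length of axis 0 (16 here)

# Integer form: the count must divide the axis length evenly.
valid_counts = [num for num in (3, 4) if n % num == 0]  # only 4 qualifies

# List form: the sizes must add up to the axis length, and each piece is
# the slice that begins where the previous piece ended.
sizes = [4, 6, 6]
assert sum(sizes) == n
start = 0
for size, piece in zip(sizes, tf.split(a, sizes)):
    assert piece.numpy().tolist() == a[start:start + size].numpy().tolist()
    start += size
```

In other words, a list of sizes is just a compact way of describing consecutive slice boundaries, which is why the split sizes can be derived from the positions where the sequence jumps.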

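If the layout of the consecutive runs is not known in advance, you can compute the split sizes efficiently from the differences between adjacent elements and pass them to tf.split: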

    import tensorflow as tf

    def compute_split_indices(x):
        """Return the sizes of the runs of consecutive integers in the 1-D tensor x."""
        adjacent_diffs = x[1:] - x[:-1]  # compute adjacent differences
        # a gap larger than 1 means a new run starts at the following position
        indices_where_not_continuous = tf.where(adjacent_diffs > 1) + 1
        splits = tf.concat([indices_where_not_continuous[:1], indices_where_not_continuous[1:] -
                            indices_where_not_continuous[:-1]], axis=0)  # compute split sizes from the indices
        splits_as_ints = [split.numpy().tolist()[0] for split in splits]  # convert to a list of integers for ease of use
        final_split_sizes = splits_as_ints + [len(x) - sum(splits_as_ints)]  # account for the rest of the tensor
        return final_split_sizes

    if __name__ == "__main__":
        a = tf.constant([0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 19, 20, 21, 22, 23, 24])
        splits = compute_split_indices(a)
        print(tf.split(a, splits))

Output:

[<tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 1, 2, 3], dtype=int32)>, <tf.Tensor: shape=(6,), dtype=int32, numpy=array([ 5,  6,  7,  8,  9, 10], dtype=int32)>, <tf.Tensor: shape=(6,), dtype=int32, numpy=array([19, 20, 21, 22, 23, 24], dtype=int32)>]

Note that the output is the same as when we explicitly passed [4, 6, 6].
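If the groups do not have to be separate tensors, a tf.RaggedTensor can hold them directly, which avoids the Python loop and the .numpy() calls above. Because it uses only TensorFlow ops, this approach should also work inside tf.function, where tf.split needs the number of pieces to be known when the graph is built. This is a sketch of an alternative technique, not the answer's method: split_consecutive_runs is a hypothetical helper name, and it assumes a sorted 1-D integer tensor, as in the example.

```python
import tensorflow as tf


def split_consecutive_runs(x):
    """Group a sorted 1-D integer tensor into runs of consecutive values.

    Hypothetical helper (not a TensorFlow API). Returns a tf.RaggedTensor
    with one row per run, built only from TF ops (no .numpy() calls).
    """
    x = tf.convert_to_tensor(x)
    # Difference to the previous element; the first element is compared
    # with itself, so its difference is 0.
    diffs = x - tf.concat([x[:1], x[:-1]], axis=0)
    # A gap larger than 1 marks the start of a new run.
    starts_new_run = tf.cast(diffs > 1, tf.int64)
    # Running count of run starts gives each element its row id: 0, 0, ..., 1, 1, ...
    run_ids = tf.cumsum(starts_new_run)
    return tf.RaggedTensor.from_value_rowids(x, run_ids)


a = tf.constant([0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 19, 20, 21, 22, 23, 24])
runs = split_consecutive_runs(a)
print(runs.to_list())  # [[0, 1, 2, 3], [5, 6, 7, 8, 9, 10], [19, 20, 21, 22, 23, 24]]
# The row lengths are the split sizes, if separate tensors are still needed.
print(tf.split(a, runs.row_lengths()))
```

The cumulative sum turns "a new run starts here" flags into per-element group ids, which is the same information the split sizes encode, just in a form that tf.RaggedTensor.from_value_rowids accepts directly.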