如何在 tensorflow/keras 中移动像 pandas.shift 这样的张量? (没有将最后一行移动到第一行,如 tf.roll)

How to shift a tensor like pandas.shift in tensorflow / keras? (Without shift the last row to first row, like tf.roll)

我想在给定轴上移动一个张量。在 pandas 或 numpy 中很容易做到这一点。像这样:

import numpy as np
import pandas as pd

data = np.arange(0, 6).reshape(-1, 2)
pd.DataFrame(data).shift(1).fillna(0).values

输出为:

array([[0., 0.],
[0., 1.],
[2., 3.]])

但是在tensorflow中,我找到的最接近的解决方案是tf.roll。但它 将最后一行移到第一行。 (我不想要那个)。所以我必须使用类似

的东西

tf.roll + tf.slice(remove the last row) + tf.concat(add tf.zeros to the first row).

真的.

有没有更好的方法来处理 shift 在 tensorflow 或 keras 中?

谢谢。

试试这个:

import tensorflow as tf
input = tf.constant([[0, 1, 3], [4, 5, 6], [7, 8, 9]])
shifted_0dim = input[1:]
shifted_1dim = input[:, 1:]
shifted2 = input[2:]

我想我找到了解决这个问题的更好方法。

我们可以使用 tf.roll,然后应用 tf.math.multiply 将第一行设置为零。

示例代码如下:

原始张量:

A = tf.cast(tf.reshape(tf.range(27), (-1, 3, 3)), dtype=tf.float32)
A

输出:

<tf.Tensor: id=117, shape=(3, 3, 3), dtype=float32, numpy=
array([[[ 0.,  1.,  2.],
        [ 3.,  4.,  5.],
        [ 6.,  7.,  8.]],

       [[ 9., 10., 11.],
        [12., 13., 14.],
        [15., 16., 17.]],

       [[18., 19., 20.],
        [21., 22., 23.],
        [24., 25., 26.]]], dtype=float32)>

Shift(如pd.shift):

B = tf.concat((tf.zeros((1, 3)), tf.ones((2, 3))), axis=0)
C = tf.expand_dims(B, axis=0)
tf.math.multiply(tf.roll(A, 1, axis=1), C)

输出:

<tf.Tensor: id=128, shape=(3, 3, 3), dtype=float32, numpy=
array([[[ 0.,  0.,  0.],
        [ 0.,  1.,  2.],
        [ 3.,  4.,  5.]],

       [[ 0.,  0.,  0.],
        [ 9., 10., 11.],
        [12., 13., 14.]],

       [[ 0.,  0.,  0.],
        [18., 19., 20.],
        [21., 22., 23.]]], dtype=float32)>

假设一个 2d 张量,这个函数应该模拟一个 Dataframe 移位:

def shift_tensor(tensor, periods, fill_value):
    num_row = len(tensor)
    num_col = len(tensor[0])
    pad = tf.fill([periods, num_col], fill_value)
    
    if periods > 0:
        shifted_tensor = tf.concat((pad, tensor[:(num_row - periods), :]), axis=0)
    else:
        shifted_tensor = tf.concat((tensor[:(num_row - periods), :], pad), axis=0)
    
    return shifted_tensor

泛化为任意张量形状、所需的位移和要位移的轴:

import tensorflow as tf

def tf_shift(tensor, shift=1, axis=0):
    dim = len(tensor.shape)

    if axis > dim:
        raise ValueError(
            f'Value of axis ({axis}) must be <= number of tensor axes ({dim})'
        )

    mask_dim = dim - axis
    mask_shape = tensor.shape[-mask_dim:]
    zero_dim = min(shift, mask_shape[0])

    mask = tf.concat(
        [tf.zeros(tf.TensorShape(zero_dim) + mask_shape[1:]),
         tf.ones(tf.TensorShape(mask_shape[0] - zero_dim) + mask_shape[1:])],
        axis=0
    )

    for i in range(dim - mask_dim):
        mask = tf.expand_dims(mask, axis=0)

    return tf.multiply(
        tf.roll(tensor, shift, axis),
        mask
    )

编辑: 上面的这段代码不允许负移位值,而且速度很慢。这是一个更有效的版本,利用 tf.rolltf.concat,无需创建掩码并将感兴趣的张量乘以它。

import tensorflow as tf

def tf_shift(values: tf.Tensor, shift: int = 1, axis: int = 0):
    pad = tf.zeros([val if i != axis else abs(shift) for i, val in enumerate(values.shape)],
                   dtype=values.dtype)
    size = [-1 if i != axis else val - abs(shift) for i, val in enumerate(values.shape)]

    if shift > 0:
        shifted = tf.concat(
            [pad, tf.slice(values, [0] * len(values.shape), size)],
            axis=axis
        )
    elif shift < 0:
        shifted = tf.concat(
            [tf.slice(values, [0 if i != axis else abs(shift) for i, _ in enumerate(values.shape)], size), pad],
            axis=axis
        )
    else:
        shifted = values

    return shifted