tf.TensorArray 作为先进先出?

tf.TensorArray as a FIFO?

Here 有人指出我应该使用 tf.TensorArray,而不是 tf.Variable 或 tf.queue.FIFOQueue,来在自定义层中实现 FIFO。这是一种有效的方法吗?这里有其他选择吗?

如果这是最有效的方法,我怎样才能用 tf.TensorArray 的等价操作替换 self.queue.assign(tf.concat([self.queue[timesteps:, :], inputs], axis=0))?

代码

class FIFOLayer(Layer):
    """Fixed-capacity FIFO buffer backed by a non-trainable weight.

    Each call shifts the oldest ``timesteps`` rows out of ``self.queue`` and
    appends ``inputs`` at the end.  Slots that have never been written hold
    NaN, which is what the attention mask is derived from.

    NOTE(review): this layer mutates Python-side state (``self.count``) per
    call, so it only works in eager execution — TODO confirm it is never
    wrapped in ``tf.function``.
    """

    def __init__(self, window_size, **kwargs):
        super(FIFOLayer, self).__init__(**kwargs)

        # Maximum number of rows the queue holds.
        self.window_size = window_size
        # Number of rows pushed so far; saturates at window_size.
        self.count = 0

    def build(self, input_shape):
        super(FIFOLayer, self).build(input_shape)

        # NaN-initialized so never-written slots are detectable via is_nan.
        self.queue = self.add_weight(
            name="queue",
            shape=(self.window_size, input_shape[-1]),
            initializer=tf.initializers.Constant(value=np.nan),
            trainable=False,
        )

    def call(self, inputs, training):
        timesteps = tf.shape(inputs)[0]

        # check if batch_size is more than queue capacity
        if timesteps > self.window_size:
            raise ValueError(
                f"Batch size {int(timesteps)} exceeds queue capacity "
                f"{self.window_size}."
            )

        # 1. append new state to queue (drop the oldest `timesteps` rows)
        self.queue.assign(tf.concat([self.queue[timesteps:, :], inputs], axis=0))
        # BUG FIX: `timesteps` is a tf.Tensor, so `self.count += timesteps`
        # silently turned `count` into a tensor and `is_full` then returned a
        # tf.Tensor instead of a Python bool.  Convert to int and saturate
        # immediately instead of capping on the following call.
        self.count = min(self.count + int(timesteps), self.window_size)

        # 2. feed-forward
        if self.count < self.window_size:
            # generate mask: 1.0 for rows with no NaNs, then the pairwise
            # (window, window) mask via an outer product.
            attention_mask = tf.cast(
                tf.math.reduce_all(
                    tf.math.logical_not(tf.math.is_nan(self.queue)), axis=-1
                ),
                dtype=tf.float32,
            )
            attention_mask = tf.matmul(
                attention_mask[..., tf.newaxis],
                attention_mask[..., tf.newaxis],
                transpose_b=True,
            )
            return self.queue[tf.newaxis, ...], attention_mask

        # Queue is full: no masking needed.
        return self.queue[tf.newaxis, ...], None

    @property
    def is_full(self):
        # True once window_size rows have been pushed.
        return self.count == self.window_size

    def clear(self):
        """Reset the buffer to its initial all-NaN, empty state."""
        self.count = 0
        self.queue.assign(tf.fill(self.queue.shape, np.nan))


# Smoke test: push six (2, 12) batches through a 10-slot FIFO, then reset
# and push the last batch once more.
fifo = FIFOLayer(window_size=10)
for _ in range(6):
    batch = tf.random.normal((2, 12))
    print(fifo(batch))

print(fifo.is_full, "\n\n")

fifo.clear()

print(fifo(batch))
print(fifo.is_full, "\n\n")

使用 tf.TensorArray,您可以尝试这样的操作:

import tensorflow as tf
import numpy as np
tf.random.set_seed(111)

class FIFOLayer(tf.keras.layers.Layer):
    """FIFO buffer of the last ``window_size`` input rows, backed by a
    tf.TensorArray instead of a tf.Variable.

    Never-written slots hold NaN; the attention mask is derived from which
    rows are still NaN.  The TensorArray is fully re-scattered on every call,
    and Python-side state (``self.count``) is mutated per call, so this is
    eager-execution only.
    """

    def __init__(self, window_size, **kwargs):
        super(FIFOLayer, self).__init__(**kwargs)

        self.window_size = window_size
        # Rows pushed so far; saturates at window_size.
        self.count = 0

    def build(self, input_shape):
        super(FIFOLayer, self).build(input_shape)

        # Fill every slot with NaN so "never written" is detectable.
        self.queue_array = tf.TensorArray(dtype=tf.float32, size=self.window_size)
        self.queue_array = self.queue_array.scatter(
            tf.range(self.window_size),
            tf.constant(np.nan) * tf.ones((self.window_size, input_shape[-1])),
        )

    def call(self, inputs, training):
        timesteps = tf.shape(inputs)[0]

        # check if batch_size is more than queue capacity
        if timesteps > self.window_size:
            raise ValueError(
                f"Batch size {int(timesteps)} exceeds queue capacity "
                f"{self.window_size}."
            )

        # Shift: keep rows [timesteps, window_size) and append the new batch.
        self.queue_array = self.queue_array.scatter(
            tf.range(self.window_size),
            tf.concat(
                [self.queue_array.gather(tf.range(timesteps, self.window_size)), inputs],
                axis=0,
            ),
        )
        queue_tensor = self.queue_array.stack()
        # BUG FIX: `timesteps` is a tf.Tensor; `self.count += timesteps` made
        # `count` a tensor, which is why `is_full` printed tf.Tensor(False)
        # instead of a Python bool in the sample output.  Convert to int and
        # saturate right away.
        self.count = min(self.count + int(timesteps), self.window_size)

        # 2. feed-forward
        if self.count < self.window_size:
            # generate mask: 1.0 for rows with no NaNs, outer product to 2-D.
            attention_mask = tf.cast(
                tf.math.reduce_all(
                    tf.math.logical_not(tf.math.is_nan(queue_tensor)), axis=-1
                ),
                dtype=tf.float32,
            )
            attention_mask = tf.matmul(
                attention_mask[..., tf.newaxis],
                attention_mask[..., tf.newaxis],
                transpose_b=True,
            )
            return queue_tensor[tf.newaxis, ...], attention_mask

        # Queue full: no mask needed.
        return queue_tensor[tf.newaxis, ...], None

    @property
    def is_full(self):
        # True once window_size rows have been pushed.
        return self.count == self.window_size

    def clear(self):
        """Reset to the initial all-NaN state (rebuilds the TensorArray)."""
        self.count = 0
        features = tf.shape(self.queue_array.stack())[-1]
        self.queue_array = tf.TensorArray(dtype=tf.float32, size=self.window_size)
        self.queue_array = self.queue_array.scatter(
            tf.range(self.window_size),
            tf.constant(np.nan) * tf.ones((self.window_size, features)),
        )

# Exercise the TensorArray-backed FIFO with six (2, 12) batches, then clear
# it and feed the final batch again.
fifo = FIFOLayer(window_size=10)
for _ in range(6):
    batch = tf.random.normal((2, 12))
    print(fifo(batch))

print(fifo.is_full, "\n\n")

fifo.clear()

print(fifo(batch))
print(fifo.is_full, "\n\n")
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [ 0.7558127 ,  1.5447265 ,  1.6315602 , -0.19868968,
          0.08828261,  0.01711658, -1.8133892 ,  0.12930395,
          0.47128937,  0.08567389, -1.7158676 , -0.5843805 ],
        [-0.7664911 , -0.7145203 , -1.089696  ,  0.14649415,
          0.03585422,  0.9916008 ,  0.9384322 ,  0.34755042,
         -0.09592161,  0.76490027, -1.2517685 , -1.5740465 ]]],
      dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.]], dtype=float32)>)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [ 0.7558127 ,  1.5447265 ,  1.6315602 , -0.19868968,
          0.08828261,  0.01711658, -1.8133892 ,  0.12930395,
          0.47128937,  0.08567389, -1.7158676 , -0.5843805 ],
        [-0.7664911 , -0.7145203 , -1.089696  ,  0.14649415,
          0.03585422,  0.9916008 ,  0.9384322 ,  0.34755042,
         -0.09592161,  0.76490027, -1.2517685 , -1.5740465 ],
        [-0.31995258, -0.43669155, -0.28932425, -0.06870204,
         -0.01291991,  1.171546  ,  0.75079876, -0.7693662 ,
          0.05902815,  0.60606545, -1.1038904 , -0.99837613],
        [-0.6687948 ,  0.22192897, -0.02249479, -0.08962449,
          1.2408841 ,  0.119805  , -0.53699344,  1.020805  ,
          0.9610218 ,  0.6133564 , -0.4358486 ,  2.733222  ]]],
      dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 1., 1., 1., 1.]], dtype=float32)>)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [ 0.7558127 ,  1.5447265 ,  1.6315602 , -0.19868968,
          0.08828261,  0.01711658, -1.8133892 ,  0.12930395,
          0.47128937,  0.08567389, -1.7158676 , -0.5843805 ],
        [-0.7664911 , -0.7145203 , -1.089696  ,  0.14649415,
          0.03585422,  0.9916008 ,  0.9384322 ,  0.34755042,
         -0.09592161,  0.76490027, -1.2517685 , -1.5740465 ],
        [-0.31995258, -0.43669155, -0.28932425, -0.06870204,
         -0.01291991,  1.171546  ,  0.75079876, -0.7693662 ,
          0.05902815,  0.60606545, -1.1038904 , -0.99837613],
        [-0.6687948 ,  0.22192897, -0.02249479, -0.08962449,
          1.2408841 ,  0.119805  , -0.53699344,  1.020805  ,
          0.9610218 ,  0.6133564 , -0.4358486 ,  2.733222  ],
        [-0.33772066,  0.80799913, -0.00896128,  1.606288  ,
          1.1561627 ,  0.17252289,  0.2451608 ,  1.4633939 ,
         -0.9294784 ,  0.42795137, -0.3016553 , -1.1823792 ],
        [ 0.30927372,  0.3482721 ,  1.0262096 , -0.97228396,
         -0.55333287, -0.7914886 ,  1.0115404 , -0.5656188 ,
          0.30958036, -0.8476673 ,  2.4919312 ,  0.9093976 ]]],
      dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.]], dtype=float32)>)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [        nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan,
                 nan,         nan,         nan,         nan],
        [ 0.7558127 ,  1.5447265 ,  1.6315602 , -0.19868968,
          0.08828261,  0.01711658, -1.8133892 ,  0.12930395,
          0.47128937,  0.08567389, -1.7158676 , -0.5843805 ],
        [-0.7664911 , -0.7145203 , -1.089696  ,  0.14649415,
          0.03585422,  0.9916008 ,  0.9384322 ,  0.34755042,
         -0.09592161,  0.76490027, -1.2517685 , -1.5740465 ],
        [-0.31995258, -0.43669155, -0.28932425, -0.06870204,
         -0.01291991,  1.171546  ,  0.75079876, -0.7693662 ,
          0.05902815,  0.60606545, -1.1038904 , -0.99837613],
        [-0.6687948 ,  0.22192897, -0.02249479, -0.08962449,
          1.2408841 ,  0.119805  , -0.53699344,  1.020805  ,
          0.9610218 ,  0.6133564 , -0.4358486 ,  2.733222  ],
        [-0.33772066,  0.80799913, -0.00896128,  1.606288  ,
          1.1561627 ,  0.17252289,  0.2451608 ,  1.4633939 ,
         -0.9294784 ,  0.42795137, -0.3016553 , -1.1823792 ],
        [ 0.30927372,  0.3482721 ,  1.0262096 , -0.97228396,
         -0.55333287, -0.7914886 ,  1.0115404 , -0.5656188 ,
          0.30958036, -0.8476673 ,  2.4919312 ,  0.9093976 ],
        [-0.44241378, -0.6971805 , -0.37439492,  1.0154608 ,
         -0.34494257,  0.1988212 , -0.9541314 , -0.44339198,
          0.162457  , -0.31033182, -0.34568167,  1.0341203 ],
        [-0.89020306, -0.8646532 ,  0.13348487, -0.6604107 ,
          0.07642484,  1.3407826 ,  0.79119945, -0.7598532 ,
          0.85146165, -0.2791065 , -0.4600736 ,  0.809218  ]]],
      dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[ 7.5581270e-01,  1.5447265e+00,  1.6315602e+00, -1.9868968e-01,
          8.8282607e-02,  1.7116580e-02, -1.8133892e+00,  1.2930395e-01,
          4.7128937e-01,  8.5673891e-02, -1.7158676e+00, -5.8438051e-01],
        [-7.6649112e-01, -7.1452028e-01, -1.0896960e+00,  1.4649415e-01,
          3.5854220e-02,  9.9160081e-01,  9.3843222e-01,  3.4755042e-01,
         -9.5921606e-02,  7.6490027e-01, -1.2517685e+00, -1.5740465e+00],
        [-3.1995258e-01, -4.3669155e-01, -2.8932425e-01, -6.8702042e-02,
         -1.2919909e-02,  1.1715460e+00,  7.5079876e-01, -7.6936620e-01,
          5.9028149e-02,  6.0606545e-01, -1.1038904e+00, -9.9837613e-01],
        [-6.6879481e-01,  2.2192897e-01, -2.2494787e-02, -8.9624494e-02,
          1.2408841e+00,  1.1980500e-01, -5.3699344e-01,  1.0208050e+00,
          9.6102178e-01,  6.1335641e-01, -4.3584859e-01,  2.7332220e+00],
        [-3.3772066e-01,  8.0799913e-01, -8.9612845e-03,  1.6062880e+00,
          1.1561627e+00,  1.7252289e-01,  2.4516080e-01,  1.4633939e+00,
         -9.2947841e-01,  4.2795137e-01, -3.0165529e-01, -1.1823792e+00],
        [ 3.0927372e-01,  3.4827209e-01,  1.0262096e+00, -9.7228396e-01,
         -5.5333287e-01, -7.9148859e-01,  1.0115404e+00, -5.6561881e-01,
          3.0958036e-01, -8.4766728e-01,  2.4919312e+00,  9.0939760e-01],
        [-4.4241378e-01, -6.9718051e-01, -3.7439492e-01,  1.0154608e+00,
         -3.4494257e-01,  1.9882120e-01, -9.5413142e-01, -4.4339198e-01,
          1.6245700e-01, -3.1033182e-01, -3.4568167e-01,  1.0341203e+00],
        [-8.9020306e-01, -8.6465323e-01,  1.3348487e-01, -6.6041070e-01,
          7.6424837e-02,  1.3407826e+00,  7.9119945e-01, -7.5985318e-01,
          8.5146165e-01, -2.7910650e-01, -4.6007359e-01,  8.0921799e-01],
        [-6.7833281e-01,  4.7877081e-02, -2.0416839e+00, -1.5634586e+00,
         -5.1782840e-01,  5.2898288e-01, -1.4573561e+00,  4.6455118e-01,
         -3.2871577e-01, -1.5697428e+00,  1.4454672e-01,  8.2387424e-01],
        [ 2.5552011e-03,  1.2834518e+00,  4.1382611e-01,  1.6535892e+00,
          7.8654990e-02, -1.2952465e-01,  3.6811054e-01,  1.1675907e+00,
          9.6434945e-01, -4.2399967e-01, -1.3700709e-01, -5.2056974e-01]]],
      dtype=float32)>, None)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[-3.1995258e-01, -4.3669155e-01, -2.8932425e-01, -6.8702042e-02,
         -1.2919909e-02,  1.1715460e+00,  7.5079876e-01, -7.6936620e-01,
          5.9028149e-02,  6.0606545e-01, -1.1038904e+00, -9.9837613e-01],
        [-6.6879481e-01,  2.2192897e-01, -2.2494787e-02, -8.9624494e-02,
          1.2408841e+00,  1.1980500e-01, -5.3699344e-01,  1.0208050e+00,
          9.6102178e-01,  6.1335641e-01, -4.3584859e-01,  2.7332220e+00],
        [-3.3772066e-01,  8.0799913e-01, -8.9612845e-03,  1.6062880e+00,
          1.1561627e+00,  1.7252289e-01,  2.4516080e-01,  1.4633939e+00,
         -9.2947841e-01,  4.2795137e-01, -3.0165529e-01, -1.1823792e+00],
        [ 3.0927372e-01,  3.4827209e-01,  1.0262096e+00, -9.7228396e-01,
         -5.5333287e-01, -7.9148859e-01,  1.0115404e+00, -5.6561881e-01,
          3.0958036e-01, -8.4766728e-01,  2.4919312e+00,  9.0939760e-01],
        [-4.4241378e-01, -6.9718051e-01, -3.7439492e-01,  1.0154608e+00,
         -3.4494257e-01,  1.9882120e-01, -9.5413142e-01, -4.4339198e-01,
          1.6245700e-01, -3.1033182e-01, -3.4568167e-01,  1.0341203e+00],
        [-8.9020306e-01, -8.6465323e-01,  1.3348487e-01, -6.6041070e-01,
          7.6424837e-02,  1.3407826e+00,  7.9119945e-01, -7.5985318e-01,
          8.5146165e-01, -2.7910650e-01, -4.6007359e-01,  8.0921799e-01],
        [-6.7833281e-01,  4.7877081e-02, -2.0416839e+00, -1.5634586e+00,
         -5.1782840e-01,  5.2898288e-01, -1.4573561e+00,  4.6455118e-01,
         -3.2871577e-01, -1.5697428e+00,  1.4454672e-01,  8.2387424e-01],
        [ 2.5552011e-03,  1.2834518e+00,  4.1382611e-01,  1.6535892e+00,
          7.8654990e-02, -1.2952465e-01,  3.6811054e-01,  1.1675907e+00,
          9.6434945e-01, -4.2399967e-01, -1.3700709e-01, -5.2056974e-01],
        [ 1.3070145e+00, -6.7240512e-01,  1.9308577e+00,  1.7688200e-03,
          3.0533668e-01,  6.5813893e-01,  5.2471739e-01,  2.1659613e+00,
         -8.7725663e-01,  3.5695407e-01, -1.2751107e+00, -7.7276069e-01],
        [-4.3180370e-01, -1.1814500e+00,  2.4167557e-01,  5.7490116e-01,
          5.6998456e-01, -7.4528801e-01, -9.1826969e-01, -7.3694932e-01,
         -1.2400552e+00,  1.6947891e+00, -2.6127639e-01,  7.8419834e-01]]],
      dtype=float32)>, None)
True 


(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[           nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan],
        [           nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan],
        [           nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan],
        [           nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan],
        [           nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan],
        [           nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan],
        [           nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan],
        [           nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan,
                    nan,            nan,            nan,            nan],
        [ 1.3070145e+00, -6.7240512e-01,  1.9308577e+00,  1.7688200e-03,
          3.0533668e-01,  6.5813893e-01,  5.2471739e-01,  2.1659613e+00,
         -8.7725663e-01,  3.5695407e-01, -1.2751107e+00, -7.7276069e-01],
        [-4.3180370e-01, -1.1814500e+00,  2.4167557e-01,  5.7490116e-01,
          5.6998456e-01, -7.4528801e-01, -9.1826969e-01, -7.3694932e-01,
         -1.2400552e+00,  1.6947891e+00, -2.6127639e-01,  7.8419834e-01]]],
      dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 1.]], dtype=float32)>)
tf.Tensor(False, shape=(), dtype=bool) 

附带说明一下,使用 tf.queue.FIFOQueue 真的很慢。

预处理层

 class FIFOLayer(tf.keras.layers.Layer):
    def __init__(self, window_size, **kwargs):
        super(FIFOLayer, self).__init__(**kwargs)

        self.window_size = window_size

    def build(self, input_shape):
        super(FIFOLayer, self).build(input_shape)

        # init FIFO
        self.queue = self.add_weight(
            shape=(self.window_size, input_shape[-1]), 
            initializer=tf.keras.initializers.Constant(value=np.nan),
            trainable=False
        )

    def call(self, inputs, training):
        if inputs.shape.rank == 2:
            assert self.queue.shape[-1] == inputs.shape[-1]
            timesteps = tf.shape(inputs)[0]
        elif inputs.shape.rank == 1:
            inputs = tf.reshape(inputs, (1, self.queue.shape[-1]))
            timesteps = 1
        else:
            raise ValueError("The rank of inputs is not 2 or 1 !")

        self.queue.assign(tf.concat(
            [
                self.queue[timesteps:self.window_size],
                inputs
            ],
            axis=0,
        ))

        # generate mask
        attention_mask = tf.cast(
            tf.math.reduce_all(
                tf.math.logical_not(tf.math.is_nan(self.queue)), axis=-1
            ),
            dtype=tf.float32,
        )
        attention_mask = tf.matmul(
            attention_mask[..., tf.newaxis],
            attention_mask[..., tf.newaxis],
            transpose_b=True,
        )
        return self.queue, attention_mask

    def clear(self):
        self.queue.assign(tf.fill((self.window_size, self.queue.shape[-1]), np.nan))

测试

# Stream 64 random samples through a 10-slot FIFO preprocessing layer via
# tf.data, printing the first 10 windows before and after a reset.
preprocessing_layer = FIFOLayer(window_size=10)

samples = np.random.normal(size=(64, 10))
labels = np.ones((64,))
dataset = tf.data.Dataset.from_tensor_slices((samples, labels))
dataset = dataset.map(lambda x, y: (preprocessing_layer(x), y))

for item in dataset.take(10):
    print(item)

preprocessing_layer.clear()

for item in dataset.take(10):
    print(item)