在 Tensorflow 2.0 中使用 gather 或 boolean_mask 后,张量维度变为 None

Tensor dimension becomes None after using gather or boolean_mask in Tensorflow 2.0

出于某种原因,在 TF 2 中使用 gather 时我得到了不同的张量维度:

  1. 当我使用张量作为索引向量时,第一维变为 None。
  2. 当 indices 是普通 Python 列表时,第一维为 len(indices)(这符合预期)。

这仅在非急切模式(图模式)下发生(例如在自定义损失函数内部,model.fit 默认以图模式执行它)。

(使用 boolean_mask 时也是如此)

编辑:以下代码可在 TF 2.7.0 和 Python 3.8.10 下重现该问题:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Dense, Reshape
from tensorflow.keras.datasets import mnist

def cutsom_gan_loss_env(model):
    """Build a Keras loss that differentiates ``model`` w.r.t. its input.

    Args:
        model: The Keras model being trained. The returned loss calls it on
            ``y_true`` and takes the gradient of the output with respect to
            ``y_true``.

    Returns:
        A ``custom_loss(y_true, y_pred)`` callable, as passed to
        ``Model.compile(loss=...)``.
    """

    def custom_loss(y_true, y_pred):
        """Return the mean of d model(y_true) / d y_true; ``y_pred`` is unused.

        Also prints the static shapes of two axis-0 gathers of ``y_true``:
        one indexed by a Python list, one indexed by a tensor.
        """
        # The same indices [0, 1], but produced as a tensor instead of a list.
        ff = tf.where([True, True, False, False])[:, 0]
        # NOTE: `.shape` is the *static* shape. When this loss runs inside
        # model.fit (graph mode), the tensor-indexed gather may report None
        # for dim 0; use tf.shape(...) to print the runtime shape instead.
        tf.print(tf.gather(y_true, [0, 1], axis=0).shape)
        tf.print(tf.gather(y_true, ff, axis=0).shape)

        # gradient() is called only once, so a non-persistent tape suffices.
        with tf.GradientTape() as tape:
            tape.watch(y_true)
            yy = model(y_true)
        # Take the gradient outside the recording context so the gradient
        # computation itself is not recorded on the tape.
        d_yy = tape.gradient(yy, y_true)
        return tf.reduce_mean(d_yy)

    return custom_loss


def main_():
    """Build a small dense model on MNIST and fit it with the custom loss.

    Loads MNIST with ``mnist.load_data()``, scales pixels by 1/255, builds a
    stack of sigmoid Dense layers ending in a softmax Dense + Reshape back to
    the image shape, and trains for one epoch with ``x_train`` as both the
    input and the target.
    """
    n_hidden_units = 5
    num_lay = 3
    kernel_init = keras.initializers.RandomUniform(-0.1, 0.1)

    (x_train, _y_train), _ = mnist.load_data()
    x_train = tf.cast(x_train, tf.float32) / 255.
    image_shape = x_train.shape[1:]

    inputs = Input(image_shape)
    x = Dense(n_hidden_units, kernel_initializer=kernel_init, activation='sigmoid')(inputs)
    for _ in range(num_lay):
        x = Dense(n_hidden_units, kernel_initializer=kernel_init, activation='sigmoid')(x)

    x = Dense(x_train.shape[1], kernel_initializer=kernel_init, activation='softmax')(x)
    outputs = Reshape(image_shape)(x)
    model = Model(inputs=inputs, outputs=outputs)
    model.summary()

    # epsilon and decay are left at Adam's defaults. Passing epsilon=None /
    # decay=0.0 explicitly only restated those defaults in TF 2.7, and both
    # arguments are deprecated or unsupported by later Keras optimizers.
    optimizer1 = keras.optimizers.Adam(beta_1=0.9, beta_2=0.999, amsgrad=True)
    model.compile(loss=cutsom_gan_loss_env(model), optimizer=optimizer1, metrics=None)
    model.fit(x_train, x_train, batch_size=1000, epochs=1, shuffle=False)


# Run the reproduction only when this file is executed as a script.
if __name__ == '__main__':
    main_()

这不是 bug,而是 `tensor.shape` 与 `tf.shape` 之间的区别。`tensor.shape` 是静态形状,在图模式下构建计算图时无法确定的维度会显示为 None;而 `tf.shape` 返回的是张量的动态(运行时)形状,因此在 `tf.gather` 等操作之后也能得到实际的维度大小。

将以下代码:

# Prints static shapes (Tensor.shape): in graph mode, the gather indexed by the tensor `ff` may show None for dim 0.
tf.print(tf.gather(y_true, [0, 1], axis=0).shape)
tf.print(tf.gather(y_true, ff, axis=0).shape)

改为:

# tf.shape returns the runtime (dynamic) shape, so dim 0 is concrete even when the loss runs in graph mode.
tf.print(tf.shape(tf.gather(y_true, [0, 1], axis=0)))
tf.print(tf.shape(tf.gather(y_true, ff, axis=0)))

这样,在 model.fit 期间通过 tf.shape 就能正确得到张量的形状。另请阅读相关资料(原文此处附有链接)以加深理解。