有没有办法在 Tensorflow 自定义层(在 TPU 上)动态复制张量 N 次?

Is there a way for dynamic N-times replication of a tensor in Tensorflow custom layer (on TPU)?

我正在尝试解决一个非常简单的任务(我认为它是),即在 TPU 上的自定义层中复制张量。

我的输入是 2 张量,形状为 A=(BS, H, n, C)B = (BS, n, W, C) ,其中 n 在我的例子中可以是 (1, 3, 5, 7),但也应该与其他数字一起使用。

我的任务是重复张量 A 和 B 来塑造 (BS, H, W, C) 并求和它们为输出。如果 H(或 W)总是可以被 n 整除,那将很容易,但事实并非如此。因此 A 的每个切片 (BS, H, 1, C) 的重复次数会有所不同。因此使用以下伪代码计算输出:

for i in range(W):
    A1[BS, H, i, C] = A[BS, H, floor(n*i/W), C]

我尝试以多种方式实现它:

class StripPoolingCombine(tf.keras.layers.Layer):
    def __init__(self, n=1):
        super(StripPoolingCombine, self).__init__()
        self.n = n

    def call(self, v, h, training=False):
        H, W = v.shape[1], h.shape[2]

        v_repeats = tf.unique_with_counts(tf.math.floor(tf.range(W) * self.n / W))[-1]
        h_repeats = tf.unique_with_counts(tf.math.floor(tf.range(H) * self.n / H))[-1]

        v = tf.repeat(v, repeats=v_repeats, axis=2)
        h = tf.repeat(h, repeats=h_repeats, axis=1)

        return Add()([v, h])

或者将 unique_with_counts 替换为以下逻辑:

  1. tf.math.bincount(tf.cast(tf.math.floor(tf.range(W) * self.n / W), dtype=tf.int32)
  2. 使用即兴公式:
f = tf.cast(tf.math.ceil(W / self.n), dtype=tf.int32)
s = tf.cast(tf.math.floor(W / self.n), dtype=tf.int32)
b = tf.cast(f!=s, dtype=tf.int32)
r = W - f - s * (self.n - 1)

x1 = s * tf.ones(self.n-1, dtype=tf.int32)
x2 = (1 - tf.range(r*2) % 2) * b
x2 = tf.pad(x2, paddings=[[0, self.n-r*2-1]])
x3 = tf.concat([[f], tf.add(x1, x2)], axis=0)

但是正如在 Available TensorFlow Ops 中看到的那样,对于 TPU,它不支持动态 tf.rangetf.unique_with_countstf.math.bincount,我的实现都会导致错误在构建模型并调用 model.fit()model.predict() 时。然而,我仍然希望 tensorflow 提供了一些方法来以适合我的任务的方式处理动态形状,并且我不会为这样一个微不足道的问题重写整个 Ops 模块。请帮忙!

完整的可重现示例(使用 Colab TPU):

import tensorflow as tf

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Add


try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f'Running on TPU: {tpu.master()}')
except ValueError:
    print('Could not connect to TPU')
    tpu = None

if tpu:
    try:
        print('Initializing  TPU...')
        tf.config.experimental_connect_to_cluster(tpu)
        tf.tpu.experimental.initialize_tpu_system(tpu)
        strategy = tf.distribute.TPUStrategy(tpu)
        print('TPU initialized!')
    except Exception:
        print('Failed to initialize TPU')
        

# class StripPoolingCombine(tf.keras.layers.Layer):
#     def __init__(self, n=1):
#         super(StripPoolingCombine, self).__init__()
#         self.n = n

#     def call(self, v, h, training=False):
#         H, W = v.shape[1], h.shape[2]

#         v_repeats = tf.unique_with_counts(tf.math.floor(tf.range(W) * self.n / W))[-1]
#         h_repeats = tf.unique_with_counts(tf.math.floor(tf.range(H) * self.n / H))[-1]

#         v = tf.repeat(v, repeats=v_repeats, axis=2)
#         h = tf.repeat(h, repeats=h_repeats, axis=1)

#         return Add()([v, h])


class StripPoolingCombine(tf.keras.layers.Layer):
    def __init__(self, n=1):
        super(StripPoolingCombine, self).__init__()
        self.n = n

    def call(self, v, h, training=False):
        H, W = tf.shape(v)[1], tf.shape(h)[2]

        f = tf.cast(tf.math.ceil(W / self.n), dtype=tf.int32)
        s = tf.cast(tf.math.floor(W / self.n), dtype=tf.int32)
        b = tf.cast(f!=s, dtype=tf.int32)
        r = W - f - s * (self.n - 1)

        x1 = s * tf.ones(self.n-1, dtype=tf.int32)
        x2 = (1 - tf.range(r*2) % 2) * b
        x2 = tf.pad(x2, paddings=[[0, self.n-r*2-1]])
        x3 = tf.concat([[f], tf.add(x1, x2)], axis=0)

        v = tf.repeat(v, repeats=x3, axis=2)
        h = tf.repeat(h, repeats=x3, axis=1)

        output = tf.add(v, h)

        return output


def build_model(n=7):
    v = Input(shape=(256, n, 3))
    h = Input(shape=(n, 256, 3))
    outputs = StripPoolingCombine()(v, h)

    model = Model(inputs=[v, h], outputs=outputs)

    return model


tf.keras.backend.clear_session()
with strategy.scope():
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.9, beta_2=0.999)

    model = build_model()
    model.compile(optimizer=optimizer, loss='mean_squared_error')


rng_1 = tf.random.uniform([1, 256, 7, 3])
rng_2 = tf.random.uniform([1, 7, 256, 3])

model.predict([rng_1, rng_2])

使用tf.gather:

def call(self, v, h, training=False):
    def out(A, H, axis):
      r = tf.range(H)
      inds = tf.floor(self.n * r / H)
      inds = tf.cast(inds, tf.int32)
      return tf.gather(A, inds, axis=axis)
    
    H, W = tf.shape(v)[1], tf.shape(h)[2]
    v = out(v, W, 2)
    h = out(h, H, 1)

    output = tf.add(v, h)

    return output