Tensorflow,如何实现排序层

Tensorflow, how to implement sorting layer

我正在尝试在 keras 中创建一个层,该层采用平坦张量 x(其中没有零值并且形状 = (batch_size, units)) 乘以 a mask(形状相同),它将按照屏蔽值在输出中首先放置的方式对其进行排序(元素值的顺序无关紧要)。为清楚起见,这里有一个示例 (batch_size = 1, units = 8):

看似简单,问题是我找不到好的解决办法。任何代码或想法都会受到赞赏。

我目前的代码如下,如果你知道更有效的方法请告诉我。

class Sort(keras.layers.Layer):
  
  def call(self, inputs):
    x = inputs.numpy()
    nonx, nony = x.nonzero()    # idxs of nonzero elements
    zero = [np.where(x == 0)[0][0], np.where(x == 0)[1][0]]   # idx of first zero
    x_shape = tf.shape(inputs)
    result = np.zeros((x_shape[0], x_shape[1], 2), dtype = 'int')   # mapping matrix

    result[:, :, 0] += zero[0]
    result[:, :, 1] += zero[1]
    p = np.zeros((x_shape[0]), dtype = 'int')
    for i, j in zip(nonx, nony):
      result[i, p[i]] = [i, j]
      p[i] += 1

    y = tf.gather_nd(inputs, result)
    return y