Jax、jit 和动态形状:Tensorflow 的回归?

Jax, jit and dynamic shapes: a regression from Tensorflow?

documentation for JAX 说,

Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time.

现在我有点惊讶,因为 tensorflow 有像 tf.boolean_mask 这样的操作,它可以做 JAX 在编译时似乎无法做的事情。

  1. 为什么 Tensorflow 会出现这种回归?我假设底层 XLA 表示在两个框架之间共享,但我可能错了。我不记得 Tensorflow 在动态形状方面遇到过麻烦,tf.boolean_mask 等函数一直存在。
  2. 我们能否期待这种差距在未来缩小?如果不是,为什么无法在 JAX' jit 中执行 Tensorflow(以及其他)启用的功能?

编辑

梯度通过tf.boolean_mask(显然不是在mask值上,它是离散的);这里的例子使用值未知的 TF1 样式图表,因此 TF 不能依赖它们:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

x1 = tf.placeholder(tf.float32, (3,))
x2 = tf.placeholder(tf.float32, (3,))
y = tf.boolean_mask(x1, x2 > 0)
print(y.shape)  # prints "(?,)"
dydx1, dydx2 = tf.gradients(y, [x1, x2])
assert dydx1 is not None and dydx2 is None

我认为 JAX 在这方面并不比 TensorFlow 更无能。在 JAX 中没有什么禁止您这样做:

new_array = my_array[mask]

但是,mask 应该是索引(整数)而不是布尔值。这样,JAX 就知道 new_array 的形状(与 mask 相同)。从这个意义上说,我很确定 tf.boolean_mask 是不可微分的,即如果您尝试在某个点计算它的梯度,它会引发错误。

更一般地说,如果您需要屏蔽数组,无论您使用的是什么库,都有两种方法:

  1. 如果您事先知道需要选择哪些索引并且需要提供这些索引以便库可以在编译前计算形状;
  2. 如果您无法定义这些索引,无论出于何种原因,那么您需要设计您的代码以避免防止填充影响您的结果。

每种情况的示例

  1. 假设您正在 JAX 中编写一个简单的嵌入层。 input 是一批 token 索引对应几个句子。为了得到对应于这些索引的词嵌入,我将简单地写成word_embeddings = embeddings[input]。由于我事先不知道句子的长度,所以我需要预先将所有的标记序列填充到相同的长度,这样input的形状就是(number_of_sentences, sentence_max_length)。现在,每次此形状发生变化时,JAX 都会编译屏蔽操作。为了最小化编译次数,可以提供相同数量的句子(也称为批量大小),并且可以将 sentence_max_length 设置为整个语料库中的最大句子长度。这样,训练时就只有一次编译。当然,你需要在word_embeddings中预留一行对应pad索引。但是,掩蔽仍然有效。

  2. 稍后在模型中,假设您想将每个句子的每个单词表示为句子中所有其他单词的加权平均值(如自我注意机制)。为整个批次并行计算权重,并存储在维度 (number_of_sentences, sentence_max_length, sentence_max_length) 的矩阵 A 中。使用公式 A @ word_embeddings 计算加权平均值。现在,您需要确保 pad 标记不会影响之前的公式。为此,您可以将 A 中对应于焊盘索引的条目置零,以消除它们对平均的影响。如果 pad 令牌索引为 0,你会这样做:

    mask = jnp.array(input > 0, dtype=jnp.float32)
    A = A * mask[:, jnp.newaxis, :]
    weighted_mean = A @ word_embeddings 

所以这里我们使用了一个布尔掩码,但掩码在某种程度上是可微的,因为我们将掩码与另一个矩阵相乘,而不是将其用作索引。请注意,我们应该以相同的方式删除也对应于 pad 标记的 weighted_mean 行。

目前,您不能如此处所述

这不是 JAX jit 与 TensorFlow 的限制,而是 XLA 或两者编译方式的限制。

JAX 只使用 XLA 来编译函数。 XLA 需要知道 静态形状。这是 XLA 中固有的设计选择

TensorFlow 使用 function:这会创建一个图形,其形状可能不是静态已知的。这不如使用 XLA 有效,但仍然可以。但是,tf.function 提供了一个选项 jit_compile,它将使用 XLA 编译函数内部的图形。虽然这通常提供了不错的加速(免费),但它有一些限制:形状需要静态已知(惊喜,惊喜,...)

这总体上并不是太令人惊讶的行为:计算机中的计算通常更快(假设有一个不错的优化器经过它)以前已知的越多参数越多(内存布局,...) 可以优化调度。知道的越少,代码越慢(在这一端是正常的Python)。