TF map_fn 或 while_loop 用于不同形状的张量列表

TF map_fn or while_loop for a list of tensors of different shape

我想处理不同形状的张量序列(列表)并输出另一个张量列表。考虑在每个时间戳上具有不同隐藏状态大小的 RNN。像

输入:[tf.ones((1, 2, 2)), tf.ones((2, 2, 3)), tf.ones((3, 2, 1 ))]

输出:[tf.zeros((1, 2, 4)), tf.zeros((4, 2, 6)), tf.zeros((6, 2, 1 ))]

我无法将输入(或输出)堆叠到单个张量中,因为它们都有不同的形状,因此我无法使用 tf.map_fn 来完成任务。现在,我使用 python for 循环,但它似乎不是最优的。

我能做些什么更好的事情吗?

您可以使用 tf.while_loop 重复执行任意 TensorFlow 操作,直到出现某种停止条件。停止条件本身指定为操作。

请注意 tf.while_loop 应谨慎使用,因为默认情况下它的迭代将 运行 并行。例如,如果循环体递增一个tf.Variable,那么你必须使用control dependencies来确保迭代运行顺序。

但是, 您提到您有一个带有 Python 循环的有效实现。如果可能,对循环使用 Python 通常是最有效的解决方案。当您在 Python 中构建循环时,您会为循环中的每个迭代创建单独的操作。这让 TensorFlow 在 graph-building 决定如何为每个操作分配计算资源。例如,如果事先知道迭代次数,则内存需求和并行化可能性更容易预测。

因此,tf.while_looptf.map_fn 最常用于在 graph-building 时不知道停止条件的情况。

如果有固定但非常大的迭代次数,您可能仍想使用 tf.while_loop 而不是 Python 循环,因为每个操作的内存成本不小。