如何在 tf.function(图形模式)中展平梯度(张量列表)

How to flatten a gradient (list of tensors) in tf.function (graph mode)

我想用梯度做一些线性代数(例如tf.matmul)。默认情况下,梯度以张量列表的形式返回,其中张量可能具有不同的形状。我的解决方案是将梯度重塑为单个向量。这在 eager 模式下有效,但现在我想使用 tf.function 编译我的代码。似乎没有办法编写一个可以'flatten'图形模式梯度的函数(tf.function)。

grad = [tf.ones((2,10)), tf.ones((3,))]  # an example of what a gradient from tape.gradient can look like

# this works for flattening the gradient in eager mode only
def flatten_grad(grad):
    return tf.concat([tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i]))) for i in range(len(grad))], 0)

我试过这样转换它,但它也不适用于 tf.function。

@tf.function
def flatten_grad1(grad):
    temp = [None]*len(grad)
    for i in tf.range(len(grad)):
        i = tf.cast(i, tf.int32)
        temp[i] = tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i])))
    return tf.concat(temp, 0)

我试过 TensorArrays,但还是不行。

@tf.function
def flatten_grad2(grad):
    temp = tf.TensorArray(tf.float32, size=len(grad), infer_shape=False)
    for i in tf.range(len(grad)):
        i = tf.cast(i, tf.int32)
        temp = temp.write(i, tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i]))))
    return temp.concat()

嗨,我认为最大的问题是 python 不鼓励计算循环的循环。

下面是一个示例,说明如何使用 tf 函数使梯度变量变平看起来有点奇怪,通常应该与批处理的形状一致

import tensorflow as tf
import numpy as np

@tf.function
def flatten(arr):
     dim = tf.math.reduce_prod(tf.shape(arr)[1:])
     return tf.reshape(arr, [-1, dim])

grad = tf.Variable(np.random.randn(100, 10, 10, 3))

flatten_grad = flatten(grad)

也许您可以尝试直接迭代您的 list 张量,而不是通过索引获取单个张量:

import tensorflow as tf

grad = [tf.ones((2,10)), tf.ones((3,))]  # an example of what a gradient from tape.gradient can look like

@tf.function
def flatten_grad1(grad):
    temp = [None]*len(grad)
    for i, g in enumerate(grad):
        temp[i] = tf.reshape(g, (tf.math.reduce_prod(tf.shape(g)), ))
    return tf.concat(temp, axis=0)
print(flatten_grad1(grad))
tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(23,), dtype=float32)

tf.TensorArray:

@tf.function
def flatten_grad2(grad):
    temp = tf.TensorArray(tf.float32, size=0, dynamic_size=True, infer_shape=False)
    for g in grad:
        temp = temp.write(temp.size(), tf.reshape(g, (tf.math.reduce_prod(tf.shape(g)), )))
    return temp.concat()

print(flatten_grad2(grad))