如何在 tf.function(图形模式)中展平梯度(张量列表)
How to flatten a gradient (list of tensors) in tf.function (graph mode)
我想用梯度做一些线性代数(例如tf.matmul)。默认情况下,梯度以张量列表的形式返回,其中张量可能具有不同的形状。我的解决方案是将梯度重塑为单个向量。这在 eager 模式下有效,但现在我想使用 tf.function 编译我的代码。似乎没有办法编写一个可以'flatten'图形模式梯度的函数(tf.function)。
grad = [tf.ones((2,10)), tf.ones((3,))] # an example of what a gradient from tape.gradient can look like
# this works for flattening the gradient in eager mode only
def flatten_grad(grad):
return tf.concat([tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i]))) for i in range(len(grad))], 0)
我试过这样转换它,但它也不适用于 tf.function。
@tf.function
def flatten_grad1(grad):
temp = [None]*len(grad)
for i in tf.range(len(grad)):
i = tf.cast(i, tf.int32)
temp[i] = tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i])))
return tf.concat(temp, 0)
我试过 TensorArrays,但还是不行。
@tf.function
def flatten_grad2(grad):
temp = tf.TensorArray(tf.float32, size=len(grad), infer_shape=False)
for i in tf.range(len(grad)):
i = tf.cast(i, tf.int32)
temp = temp.write(i, tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i]))))
return temp.concat()
嗨,我认为最大的问题是 python 不鼓励计算循环的循环。
下面是一个示例,说明如何使用 tf 函数使梯度变量变平看起来有点奇怪,通常应该与批处理的形状一致
import tensorflow as tf
import numpy as np
@tf.function
def flatten(arr):
dim = tf.math.reduce_prod(tf.shape(arr)[1:])
return tf.reshape(arr, [-1, dim])
grad = tf.Variable(np.random.randn(100, 10, 10, 3))
flatten_grad = flatten(grad)
也许您可以尝试直接迭代您的 list
张量,而不是通过索引获取单个张量:
import tensorflow as tf
grad = [tf.ones((2,10)), tf.ones((3,))] # an example of what a gradient from tape.gradient can look like
@tf.function
def flatten_grad1(grad):
temp = [None]*len(grad)
for i, g in enumerate(grad):
temp[i] = tf.reshape(g, (tf.math.reduce_prod(tf.shape(g)), ))
return tf.concat(temp, axis=0)
print(flatten_grad1(grad))
tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(23,), dtype=float32)
与tf.TensorArray
:
@tf.function
def flatten_grad2(grad):
temp = tf.TensorArray(tf.float32, size=0, dynamic_size=True, infer_shape=False)
for g in grad:
temp = temp.write(temp.size(), tf.reshape(g, (tf.math.reduce_prod(tf.shape(g)), )))
return temp.concat()
print(flatten_grad2(grad))
我想用梯度做一些线性代数(例如tf.matmul)。默认情况下,梯度以张量列表的形式返回,其中张量可能具有不同的形状。我的解决方案是将梯度重塑为单个向量。这在 eager 模式下有效,但现在我想使用 tf.function 编译我的代码。似乎没有办法编写一个可以'flatten'图形模式梯度的函数(tf.function)。
grad = [tf.ones((2,10)), tf.ones((3,))] # an example of what a gradient from tape.gradient can look like
# this works for flattening the gradient in eager mode only
def flatten_grad(grad):
return tf.concat([tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i]))) for i in range(len(grad))], 0)
我试过这样转换它,但它也不适用于 tf.function。
@tf.function
def flatten_grad1(grad):
temp = [None]*len(grad)
for i in tf.range(len(grad)):
i = tf.cast(i, tf.int32)
temp[i] = tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i])))
return tf.concat(temp, 0)
我试过 TensorArrays,但还是不行。
@tf.function
def flatten_grad2(grad):
temp = tf.TensorArray(tf.float32, size=len(grad), infer_shape=False)
for i in tf.range(len(grad)):
i = tf.cast(i, tf.int32)
temp = temp.write(i, tf.reshape(grad[i], tf.math.reduce_prod(tf.shape(grad[i]))))
return temp.concat()
嗨,我认为最大的问题是 python 不鼓励计算循环的循环。
下面是一个示例,说明如何使用 tf 函数使梯度变量变平看起来有点奇怪,通常应该与批处理的形状一致
import tensorflow as tf
import numpy as np
@tf.function
def flatten(arr):
dim = tf.math.reduce_prod(tf.shape(arr)[1:])
return tf.reshape(arr, [-1, dim])
grad = tf.Variable(np.random.randn(100, 10, 10, 3))
flatten_grad = flatten(grad)
也许您可以尝试直接迭代您的 list
张量,而不是通过索引获取单个张量:
import tensorflow as tf
grad = [tf.ones((2,10)), tf.ones((3,))] # an example of what a gradient from tape.gradient can look like
@tf.function
def flatten_grad1(grad):
temp = [None]*len(grad)
for i, g in enumerate(grad):
temp[i] = tf.reshape(g, (tf.math.reduce_prod(tf.shape(g)), ))
return tf.concat(temp, axis=0)
print(flatten_grad1(grad))
tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(23,), dtype=float32)
与tf.TensorArray
:
@tf.function
def flatten_grad2(grad):
temp = tf.TensorArray(tf.float32, size=0, dynamic_size=True, infer_shape=False)
for g in grad:
temp = temp.write(temp.size(), tf.reshape(g, (tf.math.reduce_prod(tf.shape(g)), )))
return temp.concat()
print(flatten_grad2(grad))