在 Tensorflow 中将自定义梯度定义为 class 方法
Defining custom gradient as a class method in Tensorflow
我需要定义一个自定义渐变的方法如下:
class CustGradClass:
def __init__(self):
pass
@tf.custom_gradient
def f(self,x):
fx = x
def grad(dy):
return dy * 1
return fx, grad
我收到以下错误:
ValueError: Attempt to convert a value (<main.CustGradClass object at 0x12ed91710>) with an unsupported type () to a Tensor.
原因是自定义梯度接受函数 f(*x),其中 x 是张量序列。传递的第一个参数是对象本身,即 self。
f: function f(*x) that returns a tuple (y, grad_fn) where:
x is a sequence of Tensor inputs to the function.
y is a Tensor or sequence of Tensor outputs of applying TensorFlow operations in f to x.
grad_fn is a function with the signature g(*grad_ys)
如何让它发挥作用?我需要继承一些 python tensorflow class 吗?
我正在使用 tf 版本 1.12.0 和 eager 模式。
在您的示例中,您没有使用任何成员变量,因此您可以将该方法设为静态方法。如果您使用成员变量,则从成员函数调用静态方法并将成员变量作为参数传递。
class CustGradClass:
def __init__(self):
self.some_var = ...
@staticmethod
@tf.custom_gradient
def _f(x):
fx = x
def grad(dy):
return dy * 1
return fx, grad
def f(self):
return CustGradClass._f(self.some_var)
这是一种可能的简单解决方法:
import tensorflow as tf
class CustGradClass:
def __init__(self):
self.f = tf.custom_gradient(lambda x: CustGradClass._f(self, x))
@staticmethod
def _f(self, x):
fx = x * 1
def grad(dy):
return dy * 1
return fx, grad
with tf.Graph().as_default(), tf.Session() as sess:
x = tf.constant(1.0)
c = CustGradClass()
y = c.f(x)
print(tf.gradients(y, x))
# [<tf.Tensor 'gradients/IdentityN_grad/mul:0' shape=() dtype=float32>]
编辑:
如果你想在不同的 类 上多次这样做,或者只是想要一个更可重用的解决方案,你可以使用这样的装饰器,例如:
import functools
import tensorflow as tf
def tf_custom_gradient_method(f):
@functools.wraps(f)
def wrapped(self, *args, **kwargs):
if not hasattr(self, '_tf_custom_gradient_wrappers'):
self._tf_custom_gradient_wrappers = {}
if f not in self._tf_custom_gradient_wrappers:
self._tf_custom_gradient_wrappers[f] = tf.custom_gradient(lambda *a, **kw: f(self, *a, **kw))
return self._tf_custom_gradient_wrappers[f](*args, **kwargs)
return wrapped
那么你可以这样做:
class CustGradClass:
def __init__(self):
pass
@tf_custom_gradient_method
def f(self, x):
fx = x * 1
def grad(dy):
return dy * 1
return fx, grad
@tf_custom_gradient_method
def f2(self, x):
fx = x * 2
def grad(dy):
return dy * 2
return fx, grad
我需要定义一个自定义渐变的方法如下:
class CustGradClass:
def __init__(self):
pass
@tf.custom_gradient
def f(self,x):
fx = x
def grad(dy):
return dy * 1
return fx, grad
我收到以下错误:
ValueError: Attempt to convert a value (<main.CustGradClass object at 0x12ed91710>) with an unsupported type () to a Tensor.
原因是自定义梯度接受函数 f(*x),其中 x 是张量序列。传递的第一个参数是对象本身,即 self。
f: function f(*x) that returns a tuple (y, grad_fn) where:
x is a sequence of Tensor inputs to the function. y is a Tensor or sequence of Tensor outputs of applying TensorFlow operations in f to x. grad_fn is a function with the signature g(*grad_ys)
如何让它发挥作用?我需要继承一些 python tensorflow class 吗?
我正在使用 tf 版本 1.12.0 和 eager 模式。
在您的示例中,您没有使用任何成员变量,因此您可以将该方法设为静态方法。如果您使用成员变量,则从成员函数调用静态方法并将成员变量作为参数传递。
class CustGradClass:
def __init__(self):
self.some_var = ...
@staticmethod
@tf.custom_gradient
def _f(x):
fx = x
def grad(dy):
return dy * 1
return fx, grad
def f(self):
return CustGradClass._f(self.some_var)
这是一种可能的简单解决方法:
import tensorflow as tf
class CustGradClass:
def __init__(self):
self.f = tf.custom_gradient(lambda x: CustGradClass._f(self, x))
@staticmethod
def _f(self, x):
fx = x * 1
def grad(dy):
return dy * 1
return fx, grad
with tf.Graph().as_default(), tf.Session() as sess:
x = tf.constant(1.0)
c = CustGradClass()
y = c.f(x)
print(tf.gradients(y, x))
# [<tf.Tensor 'gradients/IdentityN_grad/mul:0' shape=() dtype=float32>]
编辑:
如果你想在不同的 类 上多次这样做,或者只是想要一个更可重用的解决方案,你可以使用这样的装饰器,例如:
import functools
import tensorflow as tf
def tf_custom_gradient_method(f):
@functools.wraps(f)
def wrapped(self, *args, **kwargs):
if not hasattr(self, '_tf_custom_gradient_wrappers'):
self._tf_custom_gradient_wrappers = {}
if f not in self._tf_custom_gradient_wrappers:
self._tf_custom_gradient_wrappers[f] = tf.custom_gradient(lambda *a, **kw: f(self, *a, **kw))
return self._tf_custom_gradient_wrappers[f](*args, **kwargs)
return wrapped
那么你可以这样做:
class CustGradClass:
def __init__(self):
pass
@tf_custom_gradient_method
def f(self, x):
fx = x * 1
def grad(dy):
return dy * 1
return fx, grad
@tf_custom_gradient_method
def f2(self, x):
fx = x * 2
def grad(dy):
return dy * 2
return fx, grad