如何在 google-jax 中使用渐变卷积?
How to use grad convolution in google-jax?
感谢阅读我的问题!
我刚开始学习 Jax 中的自定义 grad 函数,我发现 JAX 定义自定义函数的方法非常优雅。
有一件事让我很烦恼。
我创建了一个包装器来使松散卷积看起来像 PyTorch conv2d。
from jax import numpy as jnp
from jax.random import PRNGKey, normal
from jax import lax
from torch.nn.modules.utils import _ntuple
import jax
from jax.nn.initializers import normal
from jax import grad
torch_dims = {0: ('NC', 'OI', 'NC'), 1: ('NCH', 'OIH', 'NCH'), 2: ('NCHW', 'OIHW', 'NCHW'), 3: ('NCHWD', 'OIHWD', 'NCHWD')}
def conv(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
n = len(input.shape) - 2
if type(stride) == int:
stride = _ntuple(n)(stride)
if type(padding) == int:
padding = [(i, i) for i in _ntuple(n)(padding)]
if type(dilation) == int:
dilation = _ntuple(n)(dilation)
return lax.conv_general_dilated(lhs=input, rhs=weight, window_strides=stride, padding=padding, lhs_dilation=dilation, rhs_dilation=None, dimension_numbers=torch_dims[n], feature_group_count=1, batch_group_count=1, precision=None, preferred_element_type=None)
问题是我找不到使用它的 grad 函数的方法:
init = normal()
rng = PRNGKey(42)
x = init(rng, [128, 3, 224, 224])
k = init(rng, [64, 3, 3, 3])
y = conv(x, k)
grad(conv)(y, k)
这就是我得到的。
ValueError: conv_general_dilated lhs feature dimension size divided by feature_group_count must equal the rhs input feature dimension size, but 64 // 1 != 3.
请帮忙!
当我 运行 您的代码使用最新版本的 jax 和 jaxlib (jax==0.2.22
;jaxlib==0.1.72
) 时,我看到以下错误:
TypeError: Gradient only defined for scalar-output functions. Output had shape: (128, 64, 222, 222).
如果我创建一个使用 conv
的标量输出函数,梯度似乎有效:
result = grad(lambda x, k: conv(x, k).sum())(x, k)
print(result.shape)
# (128, 3, 224, 224)
如果您使用的是旧版本的 JAX,您可以尝试更新到更新的版本——您看到的错误可能是由于已经修复的错误造成的。
感谢阅读我的问题!
我刚开始学习 Jax 中的自定义 grad 函数,我发现 JAX 定义自定义函数的方法非常优雅。
有一件事让我很烦恼。
我创建了一个包装器来使松散卷积看起来像 PyTorch conv2d。
from jax import numpy as jnp
from jax.random import PRNGKey, normal
from jax import lax
from torch.nn.modules.utils import _ntuple
import jax
from jax.nn.initializers import normal
from jax import grad
torch_dims = {0: ('NC', 'OI', 'NC'), 1: ('NCH', 'OIH', 'NCH'), 2: ('NCHW', 'OIHW', 'NCHW'), 3: ('NCHWD', 'OIHWD', 'NCHWD')}
def conv(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
n = len(input.shape) - 2
if type(stride) == int:
stride = _ntuple(n)(stride)
if type(padding) == int:
padding = [(i, i) for i in _ntuple(n)(padding)]
if type(dilation) == int:
dilation = _ntuple(n)(dilation)
return lax.conv_general_dilated(lhs=input, rhs=weight, window_strides=stride, padding=padding, lhs_dilation=dilation, rhs_dilation=None, dimension_numbers=torch_dims[n], feature_group_count=1, batch_group_count=1, precision=None, preferred_element_type=None)
问题是我找不到使用它的 grad 函数的方法:
init = normal()
rng = PRNGKey(42)
x = init(rng, [128, 3, 224, 224])
k = init(rng, [64, 3, 3, 3])
y = conv(x, k)
grad(conv)(y, k)
这就是我得到的。
ValueError: conv_general_dilated lhs feature dimension size divided by feature_group_count must equal the rhs input feature dimension size, but 64 // 1 != 3.
请帮忙!
当我 运行 您的代码使用最新版本的 jax 和 jaxlib (jax==0.2.22
;jaxlib==0.1.72
) 时,我看到以下错误:
TypeError: Gradient only defined for scalar-output functions. Output had shape: (128, 64, 222, 222).
如果我创建一个使用 conv
的标量输出函数,梯度似乎有效:
result = grad(lambda x, k: conv(x, k).sum())(x, k)
print(result.shape)
# (128, 3, 224, 224)
如果您使用的是旧版本的 JAX,您可以尝试更新到更新的版本——您看到的错误可能是由于已经修复的错误造成的。