如何在mxnet中使用像torch.nn.functional.conv2d()这样的函数?
How to use the func like torch.nn.functional.conv2d() in mxnet?
我想用输入数据和内核做一些卷积计算。
在火炬中,我可以写一个函数:
import torch
def torch_conv_func(x, num_groups):
batch_size, num_channels, height, width = x.size()
conv_kernel = torch.ones(num_channels, num_channels, 1, 1)
return torch.nn.functional.conv2d(x, conv_kernel)
它工作得很好,现在我需要在 MXnet 中重建,所以我这样写:
from mxnet import nd
from mxnet.gluon import nn
def mxnet_conv_func(x, num_groups):
batch_size, num_channels, height, width = x.shape
conv_kernel = nd.ones((num_channels, num_channels, 1, 1))
return nd.Convolution(x, conv_kernel)
我得到了错误
mxnet.base.MXNetError: Required parameter kernel of Shape(tuple) is not presented, in operator Convolution(name="")
如何解决?
您缺少 mxnet.nd.Convolution
的一些额外参数。你可以这样做:
from mxnet import nd
def mxnet_convolve(x):
B, C, H, W = x.shape
weight = nd.ones((C, C, 1, 1))
return nd.Convolution(x, weight, no_bias=True, kernel=(1,1), num_filter=C)
x = nd.ones((16, 3, 32, 32))
mxnet_convolve(x)
由于您没有使用偏差,因此需要将 no_bias
设置为 True
。此外,mxnet 要求您使用 kernel
和 num_filter
参数指定内核维度。
我想用输入数据和内核做一些卷积计算。
在火炬中,我可以写一个函数:
import torch
def torch_conv_func(x, num_groups):
batch_size, num_channels, height, width = x.size()
conv_kernel = torch.ones(num_channels, num_channels, 1, 1)
return torch.nn.functional.conv2d(x, conv_kernel)
它工作得很好,现在我需要在 MXnet 中重建,所以我这样写:
from mxnet import nd
from mxnet.gluon import nn
def mxnet_conv_func(x, num_groups):
batch_size, num_channels, height, width = x.shape
conv_kernel = nd.ones((num_channels, num_channels, 1, 1))
return nd.Convolution(x, conv_kernel)
我得到了错误
mxnet.base.MXNetError: Required parameter kernel of Shape(tuple) is not presented, in operator Convolution(name="")
如何解决?
您缺少 mxnet.nd.Convolution
的一些额外参数。你可以这样做:
from mxnet import nd
def mxnet_convolve(x):
B, C, H, W = x.shape
weight = nd.ones((C, C, 1, 1))
return nd.Convolution(x, weight, no_bias=True, kernel=(1,1), num_filter=C)
x = nd.ones((16, 3, 32, 32))
mxnet_convolve(x)
由于您没有使用偏差,因此需要将 no_bias
设置为 True
。此外,mxnet 要求您使用 kernel
和 num_filter
参数指定内核维度。