将 Flux 与结构一起使用:意外的通道维度
Using Flux with structs: unexpected channel dimension
在结构中调用 Flux 函数与直接将函数应用于张量时,我似乎得到了不同的行为(不同的输出维度):
直接申请:
m = Chain(MaxPool((2,2), stride=2),Conv((3,3), 32*8=>32*16, pad=1), BatchNorm(32*16, relu),Conv((3,3), 32*16=>32*16, pad=1), BatchNorm(32*16, relu))
println(size(m(ones((32, 32, 256, 1))))) #gives the expected (16, 16, 512, 1)
通过结构:
block(in_channels, features) = Chain(MaxPool((2,2), stride=2), Conv((3,3), in_channels=>features, pad=1), BatchNorm(features, relu), Conv((3,3), features=>features, pad=1), BatchNorm(features, relu))
struct test
b
end
function test()
b = (block(32*8, 32*16))
test(b)
end
function (t::test)(x)
x1 = t.b[1](x)
println(size(x1))
end
test1 = test()
test1(ones((32, 32, 256, 1))) #gives (16, 16, 256, 1)
为什么 2 个代码段的输出通道尺寸不同?关于 Julia 中的结构,我缺少什么?谢谢!
发现错误,它与索引有关,而不是使用结构。我将 b
声明为 b = (block(32*8, 32*16))
,但通过索引 b[1]
,我实际上只调用了 Flux 链中的第一个操作(MaxPool
),这说明了通道的差异尺寸。我应该做的是 b = block(32*8, 32*16)
和 x1 = t.b(x)
到 运行 链中的所有函数。
定义 (t::test)(x)
函数的正确方法是
function (t::test)(x)
x1 = t.b(x) # Note the absence of [1]
println(size(x1))
end
t.b[1]
将给出 Chain 中的第一层,即 MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2))
,因此您的输入永远不会通过 Conv 层。
在结构中调用 Flux 函数与直接将函数应用于张量时,我似乎得到了不同的行为(不同的输出维度):
直接申请:
m = Chain(MaxPool((2,2), stride=2),Conv((3,3), 32*8=>32*16, pad=1), BatchNorm(32*16, relu),Conv((3,3), 32*16=>32*16, pad=1), BatchNorm(32*16, relu))
println(size(m(ones((32, 32, 256, 1))))) #gives the expected (16, 16, 512, 1)
通过结构:
block(in_channels, features) = Chain(MaxPool((2,2), stride=2), Conv((3,3), in_channels=>features, pad=1), BatchNorm(features, relu), Conv((3,3), features=>features, pad=1), BatchNorm(features, relu))
struct test
b
end
function test()
b = (block(32*8, 32*16))
test(b)
end
function (t::test)(x)
x1 = t.b[1](x)
println(size(x1))
end
test1 = test()
test1(ones((32, 32, 256, 1))) #gives (16, 16, 256, 1)
为什么 2 个代码段的输出通道尺寸不同?关于 Julia 中的结构,我缺少什么?谢谢!
发现错误,它与索引有关,而不是使用结构。我将 b
声明为 b = (block(32*8, 32*16))
,但通过索引 b[1]
,我实际上只调用了 Flux 链中的第一个操作(MaxPool
),这说明了通道的差异尺寸。我应该做的是 b = block(32*8, 32*16)
和 x1 = t.b(x)
到 运行 链中的所有函数。
定义 (t::test)(x)
函数的正确方法是
function (t::test)(x)
x1 = t.b(x) # Note the absence of [1]
println(size(x1))
end
t.b[1]
将给出 Chain 中的第一层,即 MaxPool((2, 2), pad = (0, 0, 0, 0), stride = (2, 2))
,因此您的输入永远不会通过 Conv 层。