RuntimeError: output with shape [1] doesn't match the broadcast shape [10]

RuntimeError: output with shape [1] doesn't match the broadcast shape [10]

您好,我尝试使用 pytorch 模块制作 RBM 模型代码,但在可见层到隐藏层中遇到了问题。这是问题部分代码。

        h_bias = (self.h_bias.clone()).expand(10)
        v = v.clone().expand(10)
        
        p_h = F.sigmoid(
            F.linear(v, self.W, bias=h_bias)
        )

        sample_h = self.sample_from_p(p_h)
        return p_h, sample_h

每个参数大小都在这里。

h_bias           v                self.W
torch.Size([10]) torch.Size([10]) torch.Size([1, 10])
1                1                2
Traceback (most recent call last):
  File "/Users/bahk_insung/Documents/Github/ecg-dbn/model.py", line 68, in <module>
    v, v1 = rbm(sample_data)
  File "/Users/bahk_insung/miniforge3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/bahk_insung/Documents/Github/ecg-dbn/RBM.py", line 54, in forward
    pre_h1, h1 = self.v_to_h(v)
  File "/Users/bahk_insung/Documents/Github/ecg-dbn/RBM.py", line 36, in v_to_h
    F.linear(v, self.W, bias=h_bias)
  File "/Users/bahk_insung/miniforge3/lib/python3.9/site-packages/torch/nn/functional.py", line 1849, in linear
    return torch._C._nn.linear(input, weight, bias)
RuntimeError: output with shape [1] doesn't match the broadcast shape [10]

我认为尺寸和尺寸不匹配,这就是发生的原因。但我无法得到任何解决方案。请帮帮我。谢谢。

如果您查看 pytorch functional.linear 文档,它会显示权重参数可以是一维或二维:“权重:(out_features, in_features) 或 (in_features)”。由于您的权重是 2D ([1, 10]),这表明您正在尝试创建大小为“1”且输入大小为“10”的输出。线性变换不知道如何将大小为 10 的输入更改为大小为 1 的输出。如果您的权重始终为 [1, N],那么您可以使用挤压将其更改为 1D,如下所示:

F.linear(v, self.W.squeeze(), bias=h_bias)

这将创建大小为 10 的输出。

我用 torch.repeat() 函数解决了。正如曼迪亚斯所说...

you are trying to create an output of size "1" with an input size of "10". The linear transform does not know how to change your inputs of size 10 into an output of size 1.

那是我的问题。所以我像这样改变了权重输入。

w = self.W.clone().repeat(10, 1)

原来,self.W 大小是 [1, 10]。使用 repeat 函数后变为 [10, 10]。输入为 10 大小,输出为 10.

Tbh 不确定这段代码是否正确,但我必须 运行 快速编写这段代码...无论如何谢谢你们。