将 Pytorch Float 模型转换为 Double
Convert Pytorch Float Model into Double
我正在尝试解决 OpenAI Gym 的 CartPole(倒立摆小车)问题。结果发现,环境返回的状态是双精度浮点数(double),而 PyTorch 默认以单精度浮点数(float)创建模型。
class QNetworkMLP(Module):
    """Four-layer MLP that maps a state vector to one output per action.

    Layer widths are ``state_dim -> 64 -> 64 -> 128 -> num_actions``.
    A LeakyReLU follows each of the first three linear layers; the last
    linear layer's output is returned without an activation.
    """

    def __init__(self, state_dim, num_actions):
        """Create the linear layers and the activation modules.

        Args:
            state_dim: Number of input features of the first linear layer.
            num_actions: Number of output features of the last linear layer.
        """
        super(QNetworkMLP, self).__init__()
        self.l1 = Linear(state_dim, 64)
        self.l2 = Linear(64, 64)
        self.l3 = Linear(64, 128)
        self.l4 = Linear(128, num_actions)
        # `relu` is not used by forward(); it is kept so that existing
        # attribute access on the model keeps working.
        self.relu = ReLU()
        self.lrelu = LeakyReLU()

    def forward(self, x):
        """Run ``x`` through the four linear layers.

        Args:
            x: Input tensor whose last dimension is ``state_dim``. No dtype
                conversion is done here, so its dtype has to match the
                dtype of the layer weights.

        Returns:
            The output of the last linear layer (last dimension
            ``num_actions``).
        """
        x = self.lrelu(self.l1(x))
        x = self.lrelu(self.l2(x))
        x = self.lrelu(self.l3(x))
        x = self.l4(x)
        return x
我尝试通过下面这行代码把模型转换为双精度:
model = QNetworkMLP(4,2).double()
但还是不行,我得到同样的错误。
File ".\agent.py", line 117, in update_online_network
predicted_Qval = self.online_network(states_batch).gather(1,actions_batch)
File "C:\Usersabh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "C:\Usersabh\Desktop\OpenAI Gym\Cartpole\agent_model.py", line 16, in forward
x = self.lrelu(self.l1(x))
File "C:\Usersabh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "C:\Usersabh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\linear.py", line 91, in forward
return F.linear(input, self.weight, self.bias)
File "C:\Usersabh\anaconda3\envs\gym\lib\site-packages\torch\nn\functional.py", line 1674, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'mat1' in call to _th_addmm
你可以在初始化你的模型后试试这个吗:
model.to(torch.double)
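另外,一定要检查传给模型的所有输入是否都是 torch.double 数据类型。从报错信息看,类型为 Float 的是参数 'mat1',也就是输入张量而不是模型权重,因此需要在送入网络之前把输入也转换为双精度,例如 states_batch = states_batch.double()。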
我正在尝试解决 OpenAI Gym 的 CartPole(倒立摆小车)问题。结果发现,环境返回的状态是双精度浮点数(double),而 PyTorch 默认以单精度浮点数(float)创建模型。
class QNetworkMLP(Module):
    """Four-layer MLP that maps a state vector to one output per action.

    Layer widths are ``state_dim -> 64 -> 64 -> 128 -> num_actions``.
    A LeakyReLU follows each of the first three linear layers; the last
    linear layer's output is returned without an activation.
    """

    def __init__(self, state_dim, num_actions):
        """Create the linear layers and the activation modules.

        Args:
            state_dim: Number of input features of the first linear layer.
            num_actions: Number of output features of the last linear layer.
        """
        super(QNetworkMLP, self).__init__()
        self.l1 = Linear(state_dim, 64)
        self.l2 = Linear(64, 64)
        self.l3 = Linear(64, 128)
        self.l4 = Linear(128, num_actions)
        # `relu` is not used by forward(); it is kept so that existing
        # attribute access on the model keeps working.
        self.relu = ReLU()
        self.lrelu = LeakyReLU()

    def forward(self, x):
        """Run ``x`` through the four linear layers.

        Args:
            x: Input tensor whose last dimension is ``state_dim``. No dtype
                conversion is done here, so its dtype has to match the
                dtype of the layer weights.

        Returns:
            The output of the last linear layer (last dimension
            ``num_actions``).
        """
        x = self.lrelu(self.l1(x))
        x = self.lrelu(self.l2(x))
        x = self.lrelu(self.l3(x))
        x = self.l4(x)
        return x
我尝试通过下面这行代码把模型转换为双精度:
model = QNetworkMLP(4,2).double()
但还是不行,我得到同样的错误。
File ".\agent.py", line 117, in update_online_network
predicted_Qval = self.online_network(states_batch).gather(1,actions_batch)
File "C:\Usersabh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "C:\Usersabh\Desktop\OpenAI Gym\Cartpole\agent_model.py", line 16, in forward
x = self.lrelu(self.l1(x))
File "C:\Usersabh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "C:\Usersabh\anaconda3\envs\gym\lib\site-packages\torch\nn\modules\linear.py", line 91, in forward
return F.linear(input, self.weight, self.bias)
File "C:\Usersabh\anaconda3\envs\gym\lib\site-packages\torch\nn\functional.py", line 1674, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'mat1' in call to _th_addmm
你可以在初始化你的模型后试试这个吗:
model.to(torch.double)
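另外,一定要检查传给模型的所有输入是否都是 torch.double 数据类型。从报错信息看,类型为 Float 的是参数 'mat1',也就是输入张量而不是模型权重,因此需要在送入网络之前把输入也转换为双精度,例如 states_batch = states_batch.double()。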