TypeError: forward() got an unexpected keyword argument 'baseline value'. How do I correctly load a saved model in Skorch?

TypeError: forward() got an unexpected keyword argument 'baseline value'. How do I correctly load a saved model in Skorch?

我使用以下代码保存了我的 Skorch 神经网络模型:

net_b = NeuralNetClassifier(
    Classifier_b,
    max_epochs=50,
    optimizer__momentum= 0.9,
    lr=0.1,
    device=device,
)

#Fit the model on the full data
net_b.fit(merged_X_train, merged_Y_train);

#Test saving
import pickle
with open('MLP.pkl', 'wb') as f:
    pickle.dump(net_b, f)

当我尝试再次加载此模型并根据测试数据 运行 时,我收到以下错误:

TypeError: forward() got an unexpected keyword argument 'baseline value'

这是我的代码:

#Split the data
X_train, y_train, X_valid, y_valid,X_test, y_test = train_valid_test_split(rescaled_data, target = 'fetal_health',
                                        train_size=0.8, valid_size=0.1, test_size=0.1)

input_dim = f_df_toscale.shape[1]
output_dim = len(np.unique(f_target))
hidden_dim_a = 20
hidden_dim_b = 12
device = 'cpu'

class Classifier_b(nn.Module):
    def __init__(self,
                 input_dim = input_dim,
                 hidden_dim_a = hidden_dim_b,
                 output_dim = output_dim):
        
        super(Classifier_b, self).__init__()

        #Take the inputs and pass these to a hidden layer
        self.hidden = nn.Linear(input_dim,hidden_dim_b)
        
        #Take the hidden layer and pass it through an additional hidden layer
        self.hidden_b = nn.Linear(hidden_dim_a,hidden_dim_b)
        
        #Take the hidden layer and pass to a multi nerouon output
        self.output = nn.Linear(hidden_dim_b,output_dim)

    def forward(self, x):
        hidden = F.relu(self.hidden(x))
        hidden = F.relu(self.hidden_b(hidden))
        output = F.softmax(self.output(hidden))     
        return output

#load the model
with open('MLP.pkl', 'rb') as f:
    model_MLP = pickle.load(f)

#Test the model
y_pred = model_MLP.predict(X_test)
ML = accuracy_score(y_test, y_pred)
print('The accuracy score for the MLP is ', ML)

当我运行这个型号正常在原来的笔记本里什么都运行没问题。但是当我尝试从保存的状态加载我的模型时,我得到了错误。知道为什么吗?我什么都叫'baseline value'.

谢谢

如果代码更改,保存和加载模型可能会出现问题。所以最好用

save_params()load_params()

你的情况

net_b.save_params(f_params='some-file.pkl')

加载模型先初始化(初始化很重要)再加载参数

new_net.initialize()

new_net.load_params(f_params='some-file.pkl')