如何保存神经网络的权重

How to save weights of a neural network

我在将经过训练的神经网络的权重保存在文本文件中时遇到问题。 这是我的代码

def nNetwork(trainingData,filename):

    lamda = 1
    input_layer = 1200
    output_layer = 10
    hidden_layer = 25
    X=trainingData[0]
    y=trainingData[1]
    theta1 = randInitializeWeights(1200,25)
    theta2 = randInitializeWeights(25,10)
    m,n = np.shape(X)
    yk = recodeLabel(y,output_layer)
    theta = np.r_[theta1.T.flatten(), theta2.T.flatten()]

    X_bias = np.r_[np.ones((1,X.shape[0])), X.T]
    #conjugate gradient algo
    result = scipy.optimize.fmin_cg(computeCost,fprime=computeGradient,x0=theta,args=(input_layer,hidden_layer,output_layer,X,y,lamda,yk,X_bias),maxiter=100,disp=True,full_output=True )
    print result[1]  #min value
    theta1,theta2 = paramUnroll(result[0],input_layer,hidden_layer,output_layer)
    counter = 0
    for i in range(m):
        prediction = predict(X[i],theta1,theta2)
        actual = y[i]
        if(prediction == actual):
            counter+=1
    print  str(counter *100/m) + '% accuracy'

    data = {"Theta1":[theta1],
            "Theta2":[theta2]}
    op=open(filename,'w')
    json.dump(data,op)
    op.close()

def paramUnroll(params,input_layer,hidden_layer,labels):
    theta1_elems = (input_layer+1)*hidden_layer
    theta1_size = (input_layer+1,hidden_layer)
    theta2_size = (hidden_layer+1,labels)
    theta1 = params[:theta1_elems].T.reshape(theta1_size).T
    theta2 = params[theta1_elems:].T.reshape(theta2_size).T
    return theta1, theta2

我收到以下错误 提高类型错误(repr(o)+“不是JSON可序列化”)

请提供解决方案或任何其他方法来保存权重,以便我可以轻松地在其他代码中加载它们。

将numpy数组保存为纯文本最简单的方法是执行numpy.savetxt (and load it with numpy.loadtxt). However, if you want to save both using the JSON format you can write the files using a StringIO instance:

with StringIO as theta1IO:
    numpy.savetxt(theta1IO, theta1)
    data = {"theta1": theta1IO.getvalue() }
    # write as JSON as usual

您也可以使用其他参数来做到这一点。

要检索您可以执行的数据:

# read data from JSON
with StringIO as theta1IO:
    theta1IO.write(data['theta1'])
    theta1 = numpy.loadtxt(theta1IO)