Dueling DQN 更新模型架构并导致问题

Dueling DQN updates model architecture and causes issues

我使用以下架构创建了一个初始网络模型。

def create_model(env):

    dropout_prob = 0.8 #aggresive dropout regularization
    num_units = 256 #number of neurons in the hidden units
  
    model = Sequential()
    model.add(Flatten(input_shape=(1,) + env.input_shape))
    model.add(Dense(num_units))
    model.add(Activation('relu'))

    model.add(Dense(num_units))
    model.add(Dropout(dropout_prob))
    model.add(Activation('relu'))

    model.add(Dense(env.action_size))
    model.add(Activation('softmax'))
    print(model.summary())
    return model

然后我调用更新网络架构的DQNAgent

dqn = DQNAgent(model=model, nb_actions=env.action_size, memory=memory,
               nb_steps_warmup=settings['train']['warm_up'], 
               target_model_update=settings['train']['update_rate'], policy=policy, enable_dueling_network=True)
dqn.compile(Adam(lr=settings['train']['learning_rate']), metrics=['mse'])

这样做会产生更新的网络架构 - 正如预期的那样。现在的问题是,当我尝试调用这个拟合的新网络时,原来的创建模型函数不能接受保存的模型权重,因为层架构根本不适合。

print(model.summary())

Layer (type)                 Output Shape              Param #   
=================================================================
flatten_49 (Flatten)         (None, 106)               0         
_________________________________________________________________
dense_147 (Dense)            (None, 128)               13696     
_________________________________________________________________
activation_145 (Activation)  (None, 128)               0         
_________________________________________________________________
dense_148 (Dense)            (None, 64)                8256      
_________________________________________________________________
dropout_49 (Dropout)         (None, 64)                0         
_________________________________________________________________
activation_146 (Activation)  (None, 64)                0         
_________________________________________________________________
dense_149 (Dense)            (None, 3)                 195       
_________________________________________________________________
activation_147 (Activation)  (None, 3)                 0         
=================================================================
Total params: 22,147
Trainable params: 22,147
Non-trainable params: 0

print(dqn.model.summary())
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_49_input (InputLayer (None, 1, 1, 106)         0         
_________________________________________________________________
flatten_49 (Flatten)         (None, 106)               0         
_________________________________________________________________
dense_147 (Dense)            (None, 128)               13696     
_________________________________________________________________
activation_145 (Activation)  (None, 128)               0         
_________________________________________________________________
dense_148 (Dense)            (None, 64)                8256      
_________________________________________________________________
dropout_49 (Dropout)         (None, 64)                0         
_________________________________________________________________
activation_146 (Activation)  (None, 64)                0         
_________________________________________________________________
dense_149 (Dense)            (None, 3)                 195       
_________________________________________________________________
dense_150 (Dense)            (None, 4)                 16        
_________________________________________________________________
lambda_3 (Lambda)            (None, 3)                 0         
=================================================================
Total params: 22,163
Trainable params: 22,163
Non-trainable params: 0
_________________________________________________________________

因此,在不训练新的 dqn 的情况下,我需要找到一种创建网络架构的方法,该网络架构是根据原始架构创建的,但应用了 dqn 模型更改。

最好的方法是用

保存整个模型
dqn.model.save('xyz')

然后在新函数中加载模型而不仅仅是权重,而不是创建原始模型并将其转换为决斗形式。

model = keras.models.load_model('xyz')