tensorflow js如何从另一个模型加载权重
tensorflow js how to load weights from another model
我有两个模型 m1
和 m2
,
我想将模型 m1
的权重更新为 m2
,
在 PyTorch 的 python 中,可以使用以下代码行完成:
m1.load_state_dict(m2.state_dict())
但我在互联网上找不到任何相关信息。
我根据此文档找到的唯一内容:
https://www.tensorflow.org/js/guide/save_load
是通过local-storage保存m2
然后完全加载到m1
,但是我重新下载保存对我来说没有意义所以我可以更新权重。
Loading the weights of another model.
如问题中所示,这可以通过保存第一个模型然后将其加载为另一个模型来完成。
it doesn't make sense to me to download and save it again just so I could update the weights.
如果模型 2 不相同意味着它们具有相同的拓扑结构,则通过模型 1 的权重完全更新模型 2 是没有意义的。没有方法可以直接克隆模型并将其分配给另一个变量。为此,需要将模型作为另一个模型加载,或者将其权重复制并分配给具有相同拓扑结构的另一个模型。
model.getWeight
和model.setWeights
可以用
model2.setWeights(model1.getWeights());
如果模型2要部分更新,即更新某些层的权重,这些答案中已经讨论过 and
所以在更好地阅读文档之后,
我发现了这个:
m1.setWeights(m2.getWeights());
我也试过fit
其中一个,发现它不会学习另一个,但没有问题。
注意它们应该具有相同的结构,完整示例:
const model = tf.sequential();
model.add(tf.layers.dense({ units: 4, inputShape: [8] }));
model.add(tf.layers.dense({ units: 4 }));
model.compile({ optimizer: 'sgd', loss: 'meanSquaredError' });
const model2 = tf.sequential();
model2.add(tf.layers.dense({ units: 4, inputShape: [8] }));
model2.add(tf.layers.dense({ units: 4 }));
model2.compile({ optimizer: 'sgd', loss: 'meanSquaredError' });
model2.setWeights(model.getWeights());
console.log(model.getWeights()[0].dataSync());
console.log(model2.getWeights()[0].dataSync());
我有两个模型 m1
和 m2
,
我想将模型 m1
的权重更新为 m2
,
在 PyTorch 的 python 中,可以使用以下代码行完成:
m1.load_state_dict(m2.state_dict())
但我在互联网上找不到任何相关信息。
我根据此文档找到的唯一内容: https://www.tensorflow.org/js/guide/save_load
是通过local-storage保存m2
然后完全加载到m1
,但是我重新下载保存对我来说没有意义所以我可以更新权重。
Loading the weights of another model.
如问题中所示,这可以通过保存第一个模型然后将其加载为另一个模型来完成。
it doesn't make sense to me to download and save it again just so I could update the weights.
如果模型 2 不相同意味着它们具有相同的拓扑结构,则通过模型 1 的权重完全更新模型 2 是没有意义的。没有方法可以直接克隆模型并将其分配给另一个变量。为此,需要将模型作为另一个模型加载,或者将其权重复制并分配给具有相同拓扑结构的另一个模型。
model.getWeight
和model.setWeights
可以用
model2.setWeights(model1.getWeights());
如果模型2要部分更新,即更新某些层的权重,这些答案中已经讨论过
所以在更好地阅读文档之后,
我发现了这个:
m1.setWeights(m2.getWeights());
我也试过fit
其中一个,发现它不会学习另一个,但没有问题。
注意它们应该具有相同的结构,完整示例:
const model = tf.sequential();
model.add(tf.layers.dense({ units: 4, inputShape: [8] }));
model.add(tf.layers.dense({ units: 4 }));
model.compile({ optimizer: 'sgd', loss: 'meanSquaredError' });
const model2 = tf.sequential();
model2.add(tf.layers.dense({ units: 4, inputShape: [8] }));
model2.add(tf.layers.dense({ units: 4 }));
model2.compile({ optimizer: 'sgd', loss: 'meanSquaredError' });
model2.setWeights(model.getWeights());
console.log(model.getWeights()[0].dataSync());
console.log(model2.getWeights()[0].dataSync());