Serialization of pytorch state_dict changes after loading into new model instance
Why do the bytes obtained by pickling a pytorch state_dict change after loading that state_dict into a new instance of the same model architecture? Have a look:
import binascii
import torch.nn as nn
import pickle
lin1 = nn.Linear(1, 1, bias=False)
lin1s = pickle.dumps(lin1.state_dict())
print("--- original model ---")
print(f"hash of state dict: {hex(binascii.crc32(lin1s))}")
print(f"weight: {lin1.state_dict()['weight'].item()}")
lin2 = nn.Linear(1, 1, bias=False)
lin2.load_state_dict(pickle.loads(lin1s))
lin2s = pickle.dumps(lin2.state_dict())
print("\n--- model from deserialized state dict ---")
print(f"hash of state dict: {hex(binascii.crc32(lin2s))}")
print(f"weight: {lin2.state_dict()['weight'].item()}")
This prints:
--- original model ---
hash of state dict: 0x4806e6b6
weight: -0.30337071418762207
--- model from deserialized state dict ---
hash of state dict: 0xe2881422
weight: -0.30337071418762207
As you can see, the hashes of the (pickled) state_dicts differ, while the weight is copied over correctly. I would have assumed that the state_dict of the new model is equal to the old one in every respect. Apparently it is not, hence the differing hashes.
This is probably because pickle is not expected to produce a repr suitable for hashing (see Using pickle.dumps to hash mutable objects). It might be a better idea to compare the keys, and then compare the tensors stored under those keys for equality/closeness.
Below is a rough implementation of that idea.
import torch

def compare_state_dict(dict1, dict2):
    # make sure both dicts contain exactly the same keys
    for key in dict1:
        if key not in dict2:
            return False
    for key in dict2:
        if key not in dict1:
            return False
    # compare the tensors stored under each key element-wise
    for k, v in dict1.items():
        if not torch.all(torch.isclose(v, dict2[k])):
            return False
    return True
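Applied to the scenario from the question, this check confirms that the round-tripped state dict matches the original. Below is a brief, self-contained sketch; the compact `compare_state_dict` in it is an equivalent reformulation of the function above, not code from the original answer:

```python
import pickle
import torch
import torch.nn as nn

def compare_state_dict(dict1, dict2):
    # same keys, and element-wise close tensors under each key
    if dict1.keys() != dict2.keys():
        return False
    return all(torch.all(torch.isclose(v, dict2[k])) for k, v in dict1.items())

lin1 = nn.Linear(1, 1, bias=False)
lin2 = nn.Linear(1, 1, bias=False)
# round-trip lin1's state dict through pickle into lin2
lin2.load_state_dict(pickle.loads(pickle.dumps(lin1.state_dict())))

print(compare_state_dict(lin1.state_dict(), lin2.state_dict()))  # True
```

Even though the pickled bytes differ, the tensor contents compare equal, which is what actually matters for checking that weights were restored correctly.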
However, if you still want to hash a state dict and avoid a comparison like isclose above, you can use a function along the following lines.
def dict_hash(dictionary):
    # it did not work without hashing the tensors first;
    # build a new dict instead of mutating the caller's state dict
    hashed = {k: hash(v) for k, v in dictionary.items()}
    # dictionaries are not hashable and need to be converted to a frozenset
    return hash(frozenset(hashed.items()))
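One caveat: Python's hash of a torch tensor is identity-based, so the function above produces different results for distinct tensor objects even when their values are equal. If the goal is a hash that agrees for equal contents, a content-based variant can hash the raw tensor bytes instead. A minimal sketch, assuming the tensors are dense and on any device (`content_hash` is a hypothetical helper name):

```python
import hashlib
import torch
import torch.nn as nn

def content_hash(state_dict):
    # hash key names plus each tensor's raw values, in a fixed key order
    h = hashlib.sha256()
    for key in sorted(state_dict):
        h.update(key.encode())
        # detach/cpu so .numpy() works on parameters; tobytes() is deterministic
        h.update(state_dict[key].detach().cpu().numpy().tobytes())
    return h.hexdigest()

lin1 = nn.Linear(1, 1, bias=False)
lin2 = nn.Linear(1, 1, bias=False)
lin2.load_state_dict(lin1.state_dict())
print(content_hash(lin1.state_dict()) == content_hash(lin2.state_dict()))  # True
```

Note this is sensitive to dtype and exact bit patterns, so it tests exact equality of values rather than the tolerance-based closeness of the isclose approach.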