PyTorch: loading pretrained weights from a JSON file

I downloaded a JSON file of word embeddings. Its contents look like this:

{
"in":[0.052956,0.065460,0.066195,0.047072,0.052221,-0.082009,-0.061415,-0.116210, ...],
"for":[-0.008512,-0.034224,0.032284,0.045868,-0.013143,-0.046221,-0.000948,-0.052219, ...],
"that":[-0.012361,-0.022230,0.065540,0.039477,-0.086620,0.024913,-0.011163,-0.070522, ...],
...

I found that this function can load pretrained embeddings into PyTorch:

self.embeds = torch.nn.Embedding.from_pretrained(weights)

My question is: how do I get the .json file into the function above? I didn't find the documentation helpful. From the docs:

CLASSMETHOD from_pretrained(
    embeddings, freeze=True, padding_idx=None, max_norm=None, norm_type=2.0, 
    scale_grad_by_freq=False, sparse=False
)

embeddings (Tensor) – FloatTensor containing weights for the Embedding. 
First dimension is being passed to Embedding as num_embeddings, second as embedding_dim.
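For illustration, `from_pretrained` accepts any 2-D `FloatTensor`: rows are words, columns are embedding dimensions. A minimal sketch with a made-up 3×4 weight matrix (names `weights` and `embeds` are just placeholders):

```python
import torch

# Toy weight matrix: 3 "words", each with a 4-dimensional embedding.
weights = torch.FloatTensor([[0.1, 0.2, 0.3, 0.4],
                             [0.5, 0.6, 0.7, 0.8],
                             [0.9, 1.0, 1.1, 1.2]])

embeds = torch.nn.Embedding.from_pretrained(weights)

# First dimension became num_embeddings, second became embedding_dim.
print(embeds.num_embeddings, embeds.embedding_dim)  # 3 4

# Indexing with a LongTensor returns the corresponding row.
row = embeds(torch.tensor(1))
```

Note that `from_pretrained` freezes the weights by default (`freeze=True`), so they won't be updated during training unless you pass `freeze=False`.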

How do I convert this JSON file into a "FloatTensor" in the format this function expects?

Thanks!

# in_json is the dict produced by json.load() on the embedding file
weights = torch.stack([torch.Tensor(value) for value in in_json.values()], dim=0)