如何访问 pytorch 嵌入查找 table 作为张量

How to access pytorch embeddings lookup table as a tensor

我想用 tensorboard 投影仪展示我的嵌入。我想访问其中一层的嵌入矩阵(查找 table),以便将其写入日志。

我将图层实例化为:

self.embeddings_user = torch.nn.Embedding(30,300)

我正在寻找具有 30 个用户的形状 (30,300) 且嵌入到 300 维的张量,以用此示例代码中的 vectors 变量替换它:

import numpy as np
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile
from torch.utils.tensorboard import SummaryWriter

vectors = np.array([[0,0,1], [0,1,0], [1,0,0], [1,1,1]])
metadata = ['001', '010', '100', '111']  # labels
writer = SummaryWriter()
writer.add_embedding(vectors, metadata)
writer.close()

嵌入层具有与查找对应的权重属性 table。您可以通过以下方式访问它。

vectors = self.embeddings_user.weight

所以现在你可以用张量板可视化了。

import numpy as np
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile
from torch.utils.tensorboard import SummaryWriter

vectors = self.embeddings_user.weight
metadata = ['001', '010', '100', '111', ...]  # labels
writer = SummaryWriter()
writer.add_embedding(vectors, metadata)
writer.close()