tensorflow embedding_lookup 是可微的吗?

Is tensorflow embedding_lookup differentiable?

我遇到的一些教程描述了使用随机初始化的嵌入矩阵,然后使用 tf.nn.embedding_lookup 函数获取整数序列的嵌入。我的印象是,由于 embedding_matrix 是通过 tf.get_variable 获得的,优化器会添加适当的 ops 来更新它。

我不明白的是反向传播是如何通过查找函数发生的,它似乎是硬的而不是软的。这个操作的梯度是多少?它的输入 ID 之一?

嵌入矩阵查找在数学上等同于单热编码矩阵的点积(参见 ),这是一种平滑的线性运算。

例如,这是对索引 3:

的查找

这里是梯度的公式:

... 其中左侧是负对数似然的导数(即 objective 函数),x 是输入词,W 是嵌入矩阵,delta 是误差信号。

tf.nn.embedding_lookup 已优化,因此不会发生单热编码转换,但反向传播根据相同的公式工作。