为什么模式为 "max" 的 Pytorch EmbeddingBag 不接受 `per_sample_weights`?

Why does Pytorch EmbeddingBag with mode "max" not accept `per_sample_weights`?

Pytorch 的 EmbeddingBag 允许对不同长度的嵌入索引集合进行高效查找 + 归约操作。 reduce 操作有 3 种模式:“sum”、“average”和“max”。使用“总和”,您还可以提供 per_sample_weights 给您一个加权总和。

为什么 per_sample_weights 不允许进行“最大”操作?查看 how it's implemented,我只能假设在“Mul”操作之后执行“ReduceMean”或“ReduceMax”操作存在问题。这可能与计算梯度有关吗?


p.s: 通过除以权重总和将加权和转化为加权平均值很容易,但是对于“max”,你不能得到这样的加权等价物。

参数 per_sample_weights 仅针对 mode='sum' 实现,不是由于技术限制,而是因为开发人员没有发现“加权最大值”的用例:

I haven't been able to find use cases for "weighted mean" (which can be emulated via weighted sum) and "weighted max".