Tensorflow:根据另一个张量对张量进行采样?

Tensorflow: Sampling a tensor according to another tensor?

我有一个形状为 Batch_Size x Num_Items x Item_Dimension 的张量 T 和另一个形状为 Batch_Size x Num_Items 的张量 P,其中每批 P 中的 Num_Items 值总和为 1(每个批次的项目概率分布)。我想根据概率分布 P 从 T 中不放回 N 项进行采样。结果张量的形状应为 Batch_Size x N x Item_Dimension。我该怎么做?

看看 https://github.com/tensorflow/tensorflow/issues/9260

不过请注意,我相信您需要 logits 而不是 Gumbel 最大采样的概率。