Postgres中的递归查询(根据分数相互匹配)

Recursion query in Postgres (mutual match based on score)

我从两个不同的来源获得了记录,目标是在彼此匹配的来源之间创建 link。为此,应用了训练有素的 AI 模型,该模型为从源 A 到每个记录到源 B 的每个记录分配匹配概率分数。table score 然后看起来像这样。

src_a_id    src_b_id    score
-----------------------------
1           foo         0.8
1           bar         0.7
1           baz         0.6
2           foo         0.9
2           bar         0.5
2           baz         0.3

现在我需要从这个 table 中读取最有可能与 ID 为 1src_a 记录匹配的计数器。当你 select 数据与 sql SELECT * FROM score WHERE src_a_id = 1 ORDER BY score DESC; 你会得到这个结果。

src_a_id    src_b_id    score
-----------------------------
1           foo         0.8
1           bar         0.7
1           baz         0.6

这里看起来第一行是我要查找的结果,因此计数器匹配是 src_b 记录,ID 为 foo,相互得分为 0.8,但它不正确.我们可以从另一端查询来验证结果。什么是 ID foosrc_b 的计数器匹配?使用 sql SELECT * FROM score WHERE src_b_id = 'foo' ORDER BY score DESC; 我们得到结果:

src_a_id    src_b_id    score
-----------------------------
2           foo         0.9
1           foo         0.8

从第一个查询看来,src_a id 1 匹配 src_b id foo。 从第二个查询可以看出,前面的结论是错误的,因为 src_b id foosrc_a id 2 匹配,因为这对具有更高的相互分数。

考虑到 table 将有数千条记录,我应该如何编写查询以找到与 ID 为 1src_a 记录的匹配项?

我的第一步是在 Postgres 中搜索一些递归查询,但我发现教程不适合我的用例,老实说,到目前为止,我完全无法弥补任何可用的应用程序。

编辑

创建测试数据的演示语法:

CREATE TABLE score (
    src_a_id integer NOT NULL,
    src_b_id varchar(255) NOT NULL,
    score decimal(3,2) NOT NULL
);

INSERT INTO score (src_a_id, src_b_id, score)
VALUES 
    (1, 'foo', 0.8),
    (1, 'bar', 0.7),
    (1, 'baz', 0.6),
    (2, 'foo', 0.9),
    (2, 'bar', 0.5),    
    (2, 'baz', 0.3);

从测试数据可以推导出存在两对。

如何查询 src_a id 1 匹配项?预期结果是 src_b id bar。从另一边。如何查询 src_b id bar 匹配项?预期结果是 src_a id 1.

您的问题似乎可以通过使用 window 函数 row_number() over(<partition>) 来解决 w/o 递归。你想要的是找到这样的对,其中每个 id 的得分最大。

鉴于您提供的示例数据集 - 我们可以编写此 CTE,其中我们有 2 个行号(每个 id 一个),然后将它们相加以获得一对的排名:

with ranks as (
    select 
        src_a_id, 
        src_b_id, score,
        row_number() over (partition by src_b_id order by score desc) src_b_idx,
        row_number() over (partition by src_a_id order by score desc) 
            + row_number() over (partition by src_b_id order by score desc)  pair_rank
    from score
)

这样你会得到这个结果:

src_a_id  src_b_id  score   pair_rank
-------------------------------------
1         bar        0.7    3
1         baz        0.6    5
1         foo        0.8    3
2         bar        0.5    5
2         baz        0.8    3
2         foo        0.9    2

现在您可以选择 pair_rank 最小的配对

select src_a_id, src_b_id, score from (
    select src_a_id, src_b_id, score, 
        row_number() over (partition by src_a_id order by pair_rank, src_b_idx) as index
    from ranks
) data where index = 1 and <CONDITION> (e.g. src_a_id = <YOUR ID>)

如果没有,查询将生成得分最高的所有对

src_a_id  src_b_id  score
-------------------------
1         bar       0.7
2         foo       0.9

编辑: 在几种极端情况下,上述方法会产生 ambiguous/incorrect 结果:

  1. 给定 src_A_id 的所有对的得分低于共享相同 src_B_id 的任何其他对(如果查询 return null/0 rows/highest全部 src_A_id?)
  2. 具有相同最高分的多对共享 src_B_id(鉴于 src_A_id 不同,哪一对获胜?)
  3. 多个不同 src_A_id 给相同的最高分 src_B_id(同样,鉴于 src_A_id 和 src_B_id 相同,哪一个胜过另一个?)

给定以下数据集,您可以观察到所有 3 种情况:

src_a_id src_b_id  score
------------------------
    1    foo       0.8  |
    1    bar       0.7  | -> all pairs are beanten by some other src_a_id
    1    baz       0.6  |
    2    foo       0.9
    2    bar       0.5
    2    baz       0.8  -> higest for `baz`, but 3.baz has the same score
    3    foo       0.91 |
    3    bar       0.91 | -> both pairs are the higest, but share src_a_id
    3    baz       0.8

这里是改编后的脚本,但您可以根据需要的行为进行调整:

b_rank as (
    select 
        src_a_id, src_b_id,
        rank() over (partition by src_b_id order by score desc) src_b_idx
    from score
)

select src_a_id, src_b_id, score from (
    select 
        rank() over (partition by s.src_a_id order by score desc) score_rank, s.* 
    from score s
    join b_rank b on s.src_a_id = b.src_a_id and s.src_b_id = b.src_b_id and src_b_idx = 1
) data where score_rank = 1 and src_a_id = XXX

产量:

  • null 如果所有对都被采用(示例 src_a_id = 1)
  • 得分最高的一对,即使另一对得分相同 src_b_id(示例 src_a_id = 2)
  • 多行,如果所有这些对在给定相同 src_a_id 的情况下得分最高(示例 src_a_id = 3)