遍历笛卡尔积的一个子集,其中所有元素都被(接近)平等地选择

Iterate through a subset of a Cartesian product where all elements are selected (near-)equally

我有一大组样本,可以用三个参数来描述(我们称它们为 abc),例如在元组 (a, b, c),其中每个都可以有有限数量的值。例如,a(索引为 0..24)有 25 个可能值,b 有 20 个可能值,c 有 3 个可能值。 abc 的每个组合都在这个数据集中表示,所以在这个例子中,我的数据集有 1500 个样本 (25 × 20 × 3)。

我想从该数据集中随机 select n 个样本的子集(不重复)。但是,此随机样本必须具有 属性,即 abc 的所有可能值均等表示(或尽可能接近等同,如果selected 样本的数量不能被参数可能值的数量整除。

例如,如果我select 100个样本,我希望a的每个值被表示4次,b的每个值被表示5次,并且每个value of c to be represented 33 times (一个值可以表示34次来满足样本总数selected,这个是哪个值并不重要)。我不关心(a, b, c)的具体组合,只要每个参数值出现的总次数是正确的即可。

我目前的实现如下:

import random

n_a = 25
n_b = 20
n_c = 3

n_desired = 100

# generate random ordering for selections
order_a = random.sample(range(n_a), k=n_a)
order_b = random.sample(range(n_b), k=n_b)
order_c = random.sample(range(n_c), k=n_c)

# select random samples
samples = []
for i in range(n_desired):
    idx_a = order_a[i % n_a]
    idx_b = order_b[i % n_b]
    idx_c = order_c[i % n_c]

    samples.append((idx_a, idx_b, idx_c))

(我知道这段代码可以写得有点不同,例如使用列表理解或使用 [=34= 组合 abc 上的所有操作] 而不是 i % n 索引,但我发现这更具可读性,也是因为 abc 具有有意义但与此问题无关的原始名称代码。)

通过生成 abc 的可能值的随机排序并循环遍历它们,我们确保参数值的出现次数永远不会不同超过 1(首先,所有参数值 selected 一次,然后两次,然后三次,等等)

我们可以验证此代码是否达到了预期的结果(所有可能参数值的相等表示 (±1)):

from collections import Counter

count_a = Counter()
count_b = Counter()
count_c = Counter()

count_a.update(sample[0] for sample in samples)
count_b.update(sample[1] for sample in samples)
count_c.update(sample[2] for sample in samples)

print(f'a values are represented between {min(count_a.values())} and {max(count_a.values())} times')
print(f'b values are represented between {min(count_b.values())} and {max(count_b.values())} times')
print(f'c values are represented between {min(count_c.values())} and {max(count_c.values())} times')

这将打印以下结果:

a values are represented between 4 and 4 times
b values are represented between 5 and 5 times
c values are represented between 33 and 34 times

我们还可以使用集合的 属性 验证此代码不会 select abc 的重复组合他们不允许重复值:

print(len(set(samples)))

这会打印 100,匹配 n_desired 的值。

然而,此实现的一个问题是它仅在 n_desired ≤ lcm(n_a, n_b, n_c) 时有效,其中 lcm() 是最小公倍数(可被 n_an_bn_c 整除的最小正整数)。在我们的示例中,lcm(n_a, n_b, n_c) = lcm(25, 20, 3) = 300。如果我们 运行 上述实现与 n_desired > 300,我们将看到 selected 样本以 300 的周期重复。这是不希望的,因为这忽略了原始数据集的 80%,并且不允许我们 select 更多超过 300 个独特的样本。

一个简单的解决方案是确保 lcm(n_a, n_b, n_c) = n_a × n_b × n_c,如果这三个都是质数,情况就会如此。但是,我希望此算法适用于任何值,部分原因是我无法确保所有值都是质数(例如,在我的应用程序中,n_a 始终是整数平方的结果)。

使用 itertools.product(range(n_a), range(n_b), range(n_c)) 简单地生成一个列表为我提供了所有可能的组合,但这些组合是按顺序排列的,并且通过洗牌这个完整列表和 selecting 第一个 n_desired 样本,我们失去了所有可能参数值的相等表示的 属性。

这就是我陷入困境的地方,因为我对组合数学的了解不足以解决这个问题,也不知道我需要搜索哪些术语才能找到解决方案。我该如何解决这个问题?

您可以生成 abc 的所有随机值(每个值的列表 n_desired),然后将它们组合成一个数组。

import random

# n is the maximal value to generate
# k is the number of samples, i.e. length of the resulting list
def generate(n, k):
    # the values that are evenly distributed
    l1 = list(range(n)) * (k // n)
    # remaining values that are generated one time more than another ones
    l2 = random.sample(range(n), k % n)
    l = l1 + l2
    random.shuffle(l)
    return l
    
n_a = 25
n_b = 20
n_c = 3
n_desired = 100
l = list(zip(generate(n_a, n_desired), generate(n_b, n_desired), generate(n_c, n_desired)))
print(l)

可以删除重复项,然后使用此功能重新采样。首先,它将列表拆分为唯一和重复的样本。然后它会尝试重新排列重复值,以便生成新的唯一样本。如果它不能减少重复的数量,那么它会尝试删除一些独特的样本并使用它们来生成新样本。

def remove_duplicates(l):
    unique = set()
    duplicates = []
    for t in l:
        if t in unique:
            duplicates.append(t)
        else:
            unique.add(t)
    n_duplicates = len(duplicates)
    
    # iterations = 0
    # n_retries = 0
    while n_duplicates > 0:
        while n_duplicates > 0:
            # iterations += 1
            # print(n_duplicates)
            a, b, c = map(list, zip(*(duplicates)))
            for x in a, b, c:
                random.shuffle(x)
            duplicates = []
            for t in zip(a, b, c):
                if t in unique:
                    duplicates.append(t)
                else:
                    unique.add(t)
            if len(duplicates) == n_duplicates:
                break
            n_duplicates = len(duplicates)
        if n_duplicates > 0:
            # n_retries += 1
            n_recycled = min(n_duplicates, len(unique))
            recycled = random.sample(list(unique), n_recycled)
            unique = unique - set(recycled)
            duplicates += recycled
    # print(iterations, n_retries)
    return unique

如果 n_desired 小于所有可能样本的一半 (n_a * n_b * n_c),则效果很好,否则需要大量迭代才能完成。这个问题可以通过生成不包含在最终集中的样本来解决:

if n_desired <= n_a * n_b * n_c // 2:
    result = generate_samples(n_a, n_b, n_c, n_desired)
else:
    excluded = generate_samples(n_a, n_b, n_c, n_a * n_b * n_c - n_desired)
    all_samples = set(itertools.product(range(n_a), range(n_b), range(n_c)))
    result = all_samples - set(excluded)

这比我预期的要繁琐得多,我最终得到了相当多的代码:

from math import prod
from itertools import product
from random import shuffle

def sample(n, ns):
    # make sure parameters are valid
    if n > prod(ns):
        raise ValueError("more values requested than unique combinations", n, ns)

    # "remain" keeps track of the remaining counts for each item
    remain = []
    for n_i in ns:
        k, m = divmod(n, n_i)
        # start with the whole number
        d = {i: k for i in range(n_i)}
        # add in the remainders
        if m:
            r = list(range(n_i))
            shuffle(r)
            for i in r[:m]:
                d[i] += 1
        # sanity check
        assert(sum(d.values()) == n)

        remain.append(d)

    # generate list of all available options in random order
    opts = list(product(*(range(n_i) for n_i in ns)))
    shuffle(opts)

    result = []
    for _ in range(n):
        # get next random item, fails if we've been unlucky
        tup = opts.pop()
        result.append(tup)
        
        # keep track of remaining counts
        for i, (rem, a) in enumerate(zip(remain, tup)):
            j = rem[a]
            if j > 1:
                rem[a] = j - 1
            else:
                del rem[a]
                # remove options that involve a number that's been used up
                opts[:] = filter(lambda t: t[i] != a, opts)

    # we're done
    return result

可用作:

x = sample(100, (25, 20, 3))

请注意,这首先要生成所有可能的选项。这似乎是对您的参数的合理权衡,但如果有数十亿种可能的选择,您不应使用此算法。

另请注意,较大的 ns 会导致此算法失败,请参见下图。

随时提出改进建议,或者将其放入循环中重试 IndexError