有没有一种有效的方法可以在给定总和或平均值的范围内生成 N 个随机整数?
Is there an efficient way to generate N random integers in a range that have a given sum or average?
有没有一种有效的方法来生成 N 个整数的随机组合,使得——
- 每个整数都在区间[
min
,max
],
- 整数的总和为
sum
,
- 整数可以以任何顺序出现(例如,随机顺序),并且
- 从满足其他要求的所有组合中随机统一选择组合?
是否有类似的随机组合算法,其中整数必须按其值的排序顺序(而不是任何顺序)出现?
(如果sum = N * mean
,选择一个均值为mean
的合适组合是一个特例。这个问题相当于生成一个sum
的均匀随机划分成N个部分每个都在区间 [min
、max
] 中,并以任意顺序出现或按它们的值排序,视情况而定。)
我知道对于以随机顺序出现的组合,可以通过以下方式解决此问题(编辑 [4 月 27 日]:算法已修改。):
如果N * max < sum
或者N * min > sum
,无解
如果N * max == sum
,只有一个解,所有N
个数都等于max
。如果N * min == sum
,只有一个解,所有N
个数都等于min
。
Use the algorithm Smith and Tromble ("Sampling from the Unit Simplex", 2004) 给出的生成 N 个随机非负整数的总和 sum - N * min
.
将min
添加到以这种方式生成的每个数字。
如果任何数字大于max
,则转到步骤3。
但是,如果 max
远小于 sum
,则此算法会很慢。例如,根据我的测试(上面涉及 mean
的特殊情况的实现),该算法平均拒绝 -
- 如果
N = 7, min = 3, max = 10, sum = 42
,大约 1.6 个样本,但是
- 大约 30.6 个样本,如果
N = 20, min = 3, max = 10, sum = 120
。
有没有办法修改此算法以使其对大 N 有效,同时仍满足上述要求?
编辑:
作为评论中建议的替代方案,产生有效随机组合(满足除最后一个要求之外的所有要求)的有效方法是:
- 计算
X
,给定 sum
、min
和 max
的可能有效组合数。
- 选择
Y
,[0, X)
中的均匀随机整数。
- 将 ("unrank")
Y
转换为有效组合。
但是,有没有计算有效组合(或排列)数的公式,有没有办法将整数转换为有效组合? [编辑(4 月 28 日):排列相同而不是组合]。
编辑(4 月 27 日):
在阅读了 Devroye 的 Non-Uniform Random Variate Generation (1986) 之后,我可以确认这是一个生成随机分区的问题。此外,第 661 页的练习 2(尤其是 E 部分)与此问题相关。
编辑(4 月 28 日):
事实证明,我给出的算法是统一的,其中涉及的整数以 随机顺序 给出,而不是 按其值排序的顺序 。由于这两个问题都是普遍感兴趣的,所以我修改了这个问题以寻求这两个问题的规范答案。
以下Ruby代码可用于验证潜在的一致性解决方案(其中algorithm(...)
是候选算法):
combos={}
permus={}
mn=0
mx=6
sum=12
for x in mn..mx
for y in mn..mx
for z in mn..mx
if x+y+z==sum
permus[[x,y,z]]=0
end
if x+y+z==sum and x<=y and y<=z
combos[[x,y,z]]=0
end
end
end
end
3000.times {|x|
f=algorithm(3,sum,mn,mx)
combos[f.sort]+=1
permus[f]+=1
}
p combos
p permus
编辑(4 月 29 日):重新添加 Ruby 当前实现的代码。
Ruby中给出了以下代码示例,但我的问题与编程语言无关:
def posintwithsum(n, total)
raise if n <= 0 or total <=0
ls = [0]
ret = []
while ls.length < n
c = 1+rand(total-1)
found = false
for j in 1...ls.length
if ls[j] == c
found = true
break
end
end
if found == false;ls.push(c);end
end
ls.sort!
ls.push(total)
for i in 1...ls.length
ret.push(ls[i] - ls[i - 1])
end
return ret
end
def integersWithSum(n, total)
raise if n <= 0 or total <=0
ret = posintwithsum(n, total + n)
for i in 0...ret.length
ret[i] = ret[i] - 1
end
return ret
end
# Generate 100 valid samples
mn=3
mx=10
sum=42
n=7
100.times {
while true
pp=integersWithSum(n,sum-n*mn).map{|x| x+mn }
if !pp.find{|x| x>mx }
p pp; break # Output the sample and break
end
end
}
我还没有测试过这个,所以它不是一个真正的答案,只是尝试一下太长而无法放入评论中。从满足前两个条件的数组开始并使用它,使其仍然满足前两个条件,但更加随机。
如果平均值是一个整数,那么你的初始数组可以是 [4, 4, 4, ... 4] 或者 [3, 4, 5, 3, 4, 5, ... 5, 8, 0] 或类似的简单内容。对于 4.5 的平均值,请尝试 [4, 5, 4, 5, ... 4, 5].
接下来在数组中选择一对数字,num1
和 num2
。可能第一个数字应该按顺序取,就像 Fisher-Yates 洗牌一样,第二个数字应该随机选择。按顺序取第一个号码可确保每个号码至少被选中一次。
现在计算max-num1
和num2-min
。这些是从两个数字到 max
和 min
边界的距离。将 limit
设置为两个距离中较小的一个。这是允许的最大更改,不会使一个或另一个数字超出允许的限制。如果 limit
为零,则跳过这对。
在 [1, limit
] 范围内随机选择一个整数:称其为 change
。我从可选范围中省略了 0,因为它没有任何效果。测试可能会表明通过包含它可以获得更好的随机性;我不确定。
现在设置 num1 <- num1 + change
和 num2 <- num2 - change
。这不会影响平均值,并且数组的所有元素仍在要求的边界内。
您将需要 运行 至少遍历整个数组一次。测试应该显示您是否需要 运行 多次通过它来获得足够随机的东西。
预计到达时间:包括伪代码
// Set up the array.
resultAry <- new array size N
for (i <- 0 to N-1)
// More complex initial setup schemes are possible here.
resultAry[i] <- mean
rof
// Munge the array entries.
for (ix1 <- 0 to N-1) // ix1 steps through the array in order.
// Pick second entry different from first.
repeat
ix2 <- random(0, N-1)
until (ix2 != ix1)
// Calculate size of allowed change.
hiLimit <- max - resultAry[ix1]
loLimit <- resultAry[ix2] - min
limit <- minimum(hiLimit, loLimit)
if (limit == 0)
// No change possible so skip.
continue loop with next ix1
fi
// Change the two entries keeping same mean.
change <- random(1, limit) // Or (0, limit) possibly.
resultAry[ix1] <- resultAry[ix1] + change
resultAry[ix2] <- resultAry[ix2] - change
rof
// Check array has been sufficiently munged.
if (resultAry not random enough)
munge the array again
fi
这是我在 Java 中的解决方案。它功能齐全,包含两个生成器:PermutationPartitionGenerator
用于未排序的分区,CombinationPartitionGenerator
用于排序的分区。您的生成器也在 class SmithTromblePartitionGenerator
中实现以供比较。 class SequentialEnumerator
按顺序枚举所有可能的分区(未排序或排序,取决于参数)。我已经为所有这些生成器添加了全面的测试(包括您的测试用例)。
在大多数情况下,实现是不言自明的。有什么问题,过几天我会回复的。
import java.util.Random;
import java.util.function.Supplier;
public abstract class PartitionGenerator implements Supplier<int[]>{
public static final Random rand = new Random();
protected final int numberCount;
protected final int min;
protected final int range;
protected final int sum; // shifted sum
protected final boolean sorted;
protected PartitionGenerator(int numberCount, int min, int max, int sum, boolean sorted) {
if (numberCount <= 0)
throw new IllegalArgumentException("Number count should be positive");
this.numberCount = numberCount;
this.min = min;
range = max - min;
if (range < 0)
throw new IllegalArgumentException("min > max");
sum -= numberCount * min;
if (sum < 0)
throw new IllegalArgumentException("Sum is too small");
if (numberCount * range < sum)
throw new IllegalArgumentException("Sum is too large");
this.sum = sum;
this.sorted = sorted;
}
// Whether this generator returns sorted arrays (i.e. combinations)
public final boolean isSorted() {
return sorted;
}
public interface GeneratorFactory {
PartitionGenerator create(int numberCount, int min, int max, int sum);
}
}
import java.math.BigInteger;
// Permutations with repetition (i.e. unsorted vectors) with given sum
public class PermutationPartitionGenerator extends PartitionGenerator {
private final double[][] distributionTable;
public PermutationPartitionGenerator(int numberCount, int min, int max, int sum) {
super(numberCount, min, max, sum, false);
distributionTable = calculateSolutionCountTable();
}
private double[][] calculateSolutionCountTable() {
double[][] table = new double[numberCount + 1][sum + 1];
BigInteger[] a = new BigInteger[sum + 1];
BigInteger[] b = new BigInteger[sum + 1];
for (int i = 1; i <= sum; i++)
a[i] = BigInteger.ZERO;
a[0] = BigInteger.ONE;
table[0][0] = 1.0;
for (int n = 1; n <= numberCount; n++) {
double[] t = table[n];
for (int s = 0; s <= sum; s++) {
BigInteger z = BigInteger.ZERO;
for (int i = Math.max(0, s - range); i <= s; i++)
z = z.add(a[i]);
b[s] = z;
t[s] = z.doubleValue();
}
// swap a and b
BigInteger[] c = b;
b = a;
a = c;
}
return table;
}
@Override
public int[] get() {
int[] p = new int[numberCount];
int s = sum; // current sum
for (int i = numberCount - 1; i >= 0; i--) {
double t = rand.nextDouble() * distributionTable[i + 1][s];
double[] tableRow = distributionTable[i];
int oldSum = s;
// lowerBound is introduced only for safety, it shouldn't be crossed
int lowerBound = s - range;
if (lowerBound < 0)
lowerBound = 0;
s++;
do
t -= tableRow[--s];
// s can be equal to lowerBound here with t > 0 only due to imprecise subtraction
while (t > 0 && s > lowerBound);
p[i] = min + (oldSum - s);
}
assert s == 0;
return p;
}
public static final GeneratorFactory factory = (numberCount, min, max,sum) ->
new PermutationPartitionGenerator(numberCount, min, max, sum);
}
import java.math.BigInteger;
// Combinations with repetition (i.e. sorted vectors) with given sum
public class CombinationPartitionGenerator extends PartitionGenerator {
private final double[][][] distributionTable;
public CombinationPartitionGenerator(int numberCount, int min, int max, int sum) {
super(numberCount, min, max, sum, true);
distributionTable = calculateSolutionCountTable();
}
private double[][][] calculateSolutionCountTable() {
double[][][] table = new double[numberCount + 1][range + 1][sum + 1];
BigInteger[][] a = new BigInteger[range + 1][sum + 1];
BigInteger[][] b = new BigInteger[range + 1][sum + 1];
double[][] t = table[0];
for (int m = 0; m <= range; m++) {
a[m][0] = BigInteger.ONE;
t[m][0] = 1.0;
for (int s = 1; s <= sum; s++) {
a[m][s] = BigInteger.ZERO;
t[m][s] = 0.0;
}
}
for (int n = 1; n <= numberCount; n++) {
t = table[n];
for (int m = 0; m <= range; m++)
for (int s = 0; s <= sum; s++) {
BigInteger z;
if (m == 0)
z = a[0][s];
else {
z = b[m - 1][s];
if (m <= s)
z = z.add(a[m][s - m]);
}
b[m][s] = z;
t[m][s] = z.doubleValue();
}
// swap a and b
BigInteger[][] c = b;
b = a;
a = c;
}
return table;
}
@Override
public int[] get() {
int[] p = new int[numberCount];
int m = range; // current max
int s = sum; // current sum
for (int i = numberCount - 1; i >= 0; i--) {
double t = rand.nextDouble() * distributionTable[i + 1][m][s];
double[][] tableCut = distributionTable[i];
if (s < m)
m = s;
s -= m;
while (true) {
t -= tableCut[m][s];
// m can be 0 here with t > 0 only due to imprecise subtraction
if (t <= 0 || m == 0)
break;
m--;
s++;
}
p[i] = min + m;
}
assert s == 0;
return p;
}
public static final GeneratorFactory factory = (numberCount, min, max, sum) ->
new CombinationPartitionGenerator(numberCount, min, max, sum);
}
import java.util.*;
public class SmithTromblePartitionGenerator extends PartitionGenerator {
public SmithTromblePartitionGenerator(int numberCount, int min, int max, int sum) {
super(numberCount, min, max, sum, false);
}
@Override
public int[] get() {
List<Integer> ls = new ArrayList<>(numberCount + 1);
int[] ret = new int[numberCount];
int increasedSum = sum + numberCount;
while (true) {
ls.add(0);
while (ls.size() < numberCount) {
int c = 1 + rand.nextInt(increasedSum - 1);
if (!ls.contains(c))
ls.add(c);
}
Collections.sort(ls);
ls.add(increasedSum);
boolean good = true;
for (int i = 0; i < numberCount; i++) {
int x = ls.get(i + 1) - ls.get(i) - 1;
if (x > range) {
good = false;
break;
}
ret[i] = x;
}
if (good) {
for (int i = 0; i < numberCount; i++)
ret[i] += min;
return ret;
}
ls.clear();
}
}
public static final GeneratorFactory factory = (numberCount, min, max, sum) ->
new SmithTromblePartitionGenerator(numberCount, min, max, sum);
}
import java.util.Arrays;
// Enumerates all partitions with given parameters
public class SequentialEnumerator extends PartitionGenerator {
private final int max;
private final int[] p;
private boolean finished;
public SequentialEnumerator(int numberCount, int min, int max, int sum, boolean sorted) {
super(numberCount, min, max, sum, sorted);
this.max = max;
p = new int[numberCount];
startOver();
}
private void startOver() {
finished = false;
int unshiftedSum = sum + numberCount * min;
fillMinimal(0, Math.max(min, unshiftedSum - (numberCount - 1) * max), unshiftedSum);
}
private void fillMinimal(int beginIndex, int minValue, int fillSum) {
int fillRange = max - minValue;
if (fillRange == 0)
Arrays.fill(p, beginIndex, numberCount, max);
else {
int fillCount = numberCount - beginIndex;
fillSum -= fillCount * minValue;
int maxCount = fillSum / fillRange;
int maxStartIndex = numberCount - maxCount;
Arrays.fill(p, maxStartIndex, numberCount, max);
fillSum -= maxCount * fillRange;
Arrays.fill(p, beginIndex, maxStartIndex, minValue);
if (fillSum != 0)
p[maxStartIndex - 1] = minValue + fillSum;
}
}
@Override
public int[] get() { // returns null when there is no more partition, then starts over
if (finished) {
startOver();
return null;
}
int[] pCopy = p.clone();
if (numberCount > 1) {
int i = numberCount;
int s = p[--i];
while (i > 0) {
int x = p[--i];
if (x == max) {
s += x;
continue;
}
x++;
s--;
int minRest = sorted ? x : min;
if (s < minRest * (numberCount - i - 1)) {
s += x;
continue;
}
p[i++]++;
fillMinimal(i, minRest, s);
return pCopy;
}
}
finished = true;
return pCopy;
}
public static final GeneratorFactory permutationFactory = (numberCount, min, max, sum) ->
new SequentialEnumerator(numberCount, min, max, sum, false);
public static final GeneratorFactory combinationFactory = (numberCount, min, max, sum) ->
new SequentialEnumerator(numberCount, min, max, sum, true);
}
import java.util.*;
import java.util.function.BiConsumer;
import PartitionGenerator.GeneratorFactory;
public class Test {
private final int numberCount;
private final int min;
private final int max;
private final int sum;
private final int repeatCount;
private final BiConsumer<PartitionGenerator, Test> procedure;
public Test(int numberCount, int min, int max, int sum, int repeatCount,
BiConsumer<PartitionGenerator, Test> procedure) {
this.numberCount = numberCount;
this.min = min;
this.max = max;
this.sum = sum;
this.repeatCount = repeatCount;
this.procedure = procedure;
}
@Override
public String toString() {
return String.format("=== %d numbers from [%d, %d] with sum %d, %d iterations ===",
numberCount, min, max, sum, repeatCount);
}
private static class GeneratedVector {
final int[] v;
GeneratedVector(int[] vect) {
v = vect;
}
@Override
public int hashCode() {
return Arrays.hashCode(v);
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
return Arrays.equals(v, ((GeneratedVector)obj).v);
}
@Override
public String toString() {
return Arrays.toString(v);
}
}
private static final Comparator<Map.Entry<GeneratedVector, Integer>> lexicographical = (e1, e2) -> {
int[] v1 = e1.getKey().v;
int[] v2 = e2.getKey().v;
int len = v1.length;
int d = len - v2.length;
if (d != 0)
return d;
for (int i = 0; i < len; i++) {
d = v1[i] - v2[i];
if (d != 0)
return d;
}
return 0;
};
private static final Comparator<Map.Entry<GeneratedVector, Integer>> byCount =
Comparator.<Map.Entry<GeneratedVector, Integer>>comparingInt(Map.Entry::getValue)
.thenComparing(lexicographical);
public static int SHOW_MISSING_LIMIT = 10;
private static void checkMissingPartitions(Map<GeneratedVector, Integer> map, PartitionGenerator reference) {
int missingCount = 0;
while (true) {
int[] v = reference.get();
if (v == null)
break;
GeneratedVector gv = new GeneratedVector(v);
if (!map.containsKey(gv)) {
if (missingCount == 0)
System.out.println(" Missing:");
if (++missingCount > SHOW_MISSING_LIMIT) {
System.out.println(" . . .");
break;
}
System.out.println(gv);
}
}
}
public static final BiConsumer<PartitionGenerator, Test> distributionTest(boolean sortByCount) {
return (PartitionGenerator gen, Test test) -> {
System.out.print("\n" + getName(gen) + "\n\n");
Map<GeneratedVector, Integer> combos = new HashMap<>();
// There's no point of checking permus for sorted generators
// because they are the same as combos for them
Map<GeneratedVector, Integer> permus = gen.isSorted() ? null : new HashMap<>();
for (int i = 0; i < test.repeatCount; i++) {
int[] v = gen.get();
if (v == null && gen instanceof SequentialEnumerator)
break;
if (permus != null) {
permus.merge(new GeneratedVector(v), 1, Integer::sum);
v = v.clone();
Arrays.sort(v);
}
combos.merge(new GeneratedVector(v), 1, Integer::sum);
}
Set<Map.Entry<GeneratedVector, Integer>> sortedEntries = new TreeSet<>(
sortByCount ? byCount : lexicographical);
System.out.println("Combos" + (gen.isSorted() ? ":" : " (don't have to be uniform):"));
sortedEntries.addAll(combos.entrySet());
for (Map.Entry<GeneratedVector, Integer> e : sortedEntries)
System.out.println(e);
checkMissingPartitions(combos, test.getGenerator(SequentialEnumerator.combinationFactory));
if (permus != null) {
System.out.println("\nPermus:");
sortedEntries.clear();
sortedEntries.addAll(permus.entrySet());
for (Map.Entry<GeneratedVector, Integer> e : sortedEntries)
System.out.println(e);
checkMissingPartitions(permus, test.getGenerator(SequentialEnumerator.permutationFactory));
}
};
}
public static final BiConsumer<PartitionGenerator, Test> correctnessTest =
(PartitionGenerator gen, Test test) -> {
String genName = getName(gen);
for (int i = 0; i < test.repeatCount; i++) {
int[] v = gen.get();
if (v == null && gen instanceof SequentialEnumerator)
v = gen.get();
if (v.length != test.numberCount)
throw new RuntimeException(genName + ": array of wrong length");
int s = 0;
if (gen.isSorted()) {
if (v[0] < test.min || v[v.length - 1] > test.max)
throw new RuntimeException(genName + ": generated number is out of range");
int prev = test.min;
for (int x : v) {
if (x < prev)
throw new RuntimeException(genName + ": unsorted array");
s += x;
prev = x;
}
} else
for (int x : v) {
if (x < test.min || x > test.max)
throw new RuntimeException(genName + ": generated number is out of range");
s += x;
}
if (s != test.sum)
throw new RuntimeException(genName + ": wrong sum");
}
System.out.format("%30s : correctness test passed%n", genName);
};
public static final BiConsumer<PartitionGenerator, Test> performanceTest =
(PartitionGenerator gen, Test test) -> {
long time = System.nanoTime();
for (int i = 0; i < test.repeatCount; i++)
gen.get();
time = System.nanoTime() - time;
System.out.format("%30s : %8.3f s %10.0f ns/test%n", getName(gen), time * 1e-9, time * 1.0 / test.repeatCount);
};
public PartitionGenerator getGenerator(GeneratorFactory factory) {
return factory.create(numberCount, min, max, sum);
}
public static String getName(PartitionGenerator gen) {
String name = gen.getClass().getSimpleName();
if (gen instanceof SequentialEnumerator)
return (gen.isSorted() ? "Sorted " : "Unsorted ") + name;
else
return name;
}
public static GeneratorFactory[] factories = { SmithTromblePartitionGenerator.factory,
PermutationPartitionGenerator.factory, CombinationPartitionGenerator.factory,
SequentialEnumerator.permutationFactory, SequentialEnumerator.combinationFactory };
public static void main(String[] args) {
Test[] tests = {
new Test(3, 0, 3, 5, 3_000, distributionTest(false)),
new Test(3, 0, 6, 12, 3_000, distributionTest(true)),
new Test(50, -10, 20, 70, 2_000, correctnessTest),
new Test(7, 3, 10, 42, 1_000_000, performanceTest),
new Test(20, 3, 10, 120, 100_000, performanceTest)
};
for (Test t : tests) {
System.out.println(t);
for (GeneratorFactory factory : factories) {
PartitionGenerator candidate = t.getGenerator(factory);
t.procedure.accept(candidate, t);
}
System.out.println();
}
}
}
这是来自 John McClane 的 PermutationPartitionGenerator 的算法,在本页的另一个答案中。它有两个阶段,即设置阶段和采样阶段,并在 [min
、max
] 中生成 n
个随机变量,总和为 sum
,其中数字为以随机顺序列出。
设置阶段:首先,使用以下公式构建解决方案 table(t(y, x)
其中 y
在 [0,n
] 和 x
在 [0, sum - n * min
]):
- t(0, j) = 1 如果 j == 0; 0 否则
- t(i, j) = t(i-1, j) + t(i-1, j-1) + ... + t(i-1, j-(max-min))
这里,t(y, x) 存储 y
个数字(在适当范围内)的总和等于 x
的相对概率。这个概率是相对于所有具有相同y
.
的t(y, x)
采样阶段:这里我们生成 n
个数字的样本。将 s
设置为 sum - n * min
,然后对于每个位置 i
,从 n - 1
开始并向后计算到 0:
- 设置
v
为[0,t(i+1,s)]中的一个均匀随机整数。
- 将
r
设置为 min
。
- 从
v
中减去 t(i, s)。
- 当
v
保持为 0 或更大时,从 v
中减去 t(i, s-1),将 r
加 1,然后从 s
中减去 1 .
- 样本中位置
i
的数字设置为r
。
编辑:
看来,通过对上述算法进行微不足道的更改,可以让每个随机变量使用单独的范围,而不是对所有变量使用相同的范围:
位置i
∈[0,n
)的每个随机变量都有一个最小值min(i)和一个最大值max(i)。
令adjsum
= sum
- ∑min(i).
设置阶段:首先,使用以下公式构建解决方案 table(t(y, x)
其中 y
在 [0,n
] 和 x
在 [0, adjsum
]):
- t(0, j) = 1 如果 j == 0; 0 否则
- t(i, j) = t(i-1, j) + t(i-1, j-1) + ... + t(i-1, j-(最大(i-1)-最小(i-1)))
采样阶段与之前完全相同,只是我们将 s
设置为 adjsum
(而不是 sum - n * min
)并将 r
设置为 min(i )(而不是 min
)。
编辑:
对于 John McClane 的 CombinationPartitionGenerator,设置和采样阶段如下。
设置阶段:首先,使用以下公式构建解决方案 table(t(z, y, x)
其中 z
在 [0,n
],y
在[0,max - min
],x
在[0,sum - n * min
]):
- t(0, j, k) = 1 如果 k == 0; 0 否则
- t(i, 0, k) = t(i - 1, 0, k)
- t(i, j, k) = t(i, j-1, k) + t(i - 1, j, k - j)
采样阶段:这里我们生成 n
个数字的样本。将 s
设置为 sum - n * min
并将 mrange
设置为 max - min
,然后对于每个位置 i
,从 n - 1
开始并向后计算到 0:
- 设置
v
为[0, t(i+1, mrange, s)]中的一个均匀随机整数。
- 将
mrange
设置为最小值(mrange
, s
)
- 从
s
中减去 mrange
。
- 将
r
设置为 min + mrange
。
- 从
v
. 中减去 t(i
, mrange
, s
)
- 当
v
保持为 0 或更大时,将 s
加 1,从 r
中减去 1,从 mrange
中减去 1,然后减去 t(i
, mrange
, s
) 来自 v
.
- 样本中位置
i
的数字设置为r
。
正如 OP 指出的那样,有效取消排名的能力非常强大。如果我们能够这样做,则可以通过三个步骤生成分区的均匀分布(重申 OP 在问题中提出的内容):
- 计算长度为N[=106的分区总数M =] 的数字
sum
使得部分在 [min
, max
]. 范围内
- 从
[1, M]
. 生成均匀分布的整数
- 将步骤 2 中的每个整数取消排序到其各自的分区中。
下面,我们只关注生成第nth分区,因为生成均匀分布的信息量很大给定范围内的整数。这是一个简单的 C++
unranking 算法,应该很容易翻译成其他语言(N.B。我还没有想出如何取消组合案例(即顺序很重要))。
std::vector<int> unRank(int n, int m, int myMax, int nth) {
std::vector<int> z(m, 0);
int count = 0;
int j = 0;
for (int i = 0; i < z.size(); ++i) {
int temp = pCount(n - 1, m - 1, myMax);
for (int r = n - m, k = myMax - 1;
(count + temp) < nth && r > 0 && k; r -= m, --k) {
count += temp;
n = r;
myMax = k;
++j;
temp = pCount(n - 1, m - 1, myMax);
}
--m;
--n;
z[i] = j;
}
return z;
}
主力 pCount
函数由:
int pCount(int n, int m, int myMax) {
if (myMax * m < n) return 0;
if (myMax * m == n) return 1;
if (m < 2) return m;
if (n < m) return 0;
if (n <= m + 1) return 1;
int niter = n / m;
int count = 0;
for (; niter--; n -= m, --myMax) {
count += pCount(n - 1, m - 1, myMax);
}
return count;
}
此功能基于用户@m69_snarky_and_unwelcoming 对 的出色回答。上面给出的是对简单算法(没有记忆的那个)的轻微修改。这可以很容易地修改以合并记忆以提高效率。我们暂时将其关闭并专注于未排名的部分。
unRank
的解释
我们首先注意到长度为 N 的分区有一个一对一的映射 sum
使得部分在 [min
, max
] 到长度为 N 的受限分区范围内sum - N * (min - 1)
的数字 [1
、max - (min - 1)
].
作为一个小例子,考虑 50
长度 4
的分区,使得 min = 10
和 max = 15
。这将与长度为 4
的 50 - 4 * (10 - 1) = 14
的限制分区具有相同的结构,最大部分等于 15 - (10 - 1) = 6
.
10 10 15 15 --->> 1 1 6 6
10 11 14 15 --->> 1 2 5 6
10 12 13 15 --->> 1 3 4 6
10 12 14 14 --->> 1 3 5 5
10 13 13 14 --->> 1 4 4 5
11 11 13 15 --->> 2 2 4 6
11 11 14 14 --->> 2 2 5 5
11 12 12 15 --->> 2 3 3 6
11 12 13 14 --->> 2 3 4 5
11 13 13 13 --->> 2 4 4 4
12 12 12 14 --->> 3 3 3 5
12 12 13 13 --->> 3 3 4 4
考虑到这一点,为了方便计数,如果您愿意,我们可以添加步骤 1a 将问题转换为 "unit" 情况。
现在,我们只是遇到了一个计数问题。正如 @m69 出色地展示的那样,通过将问题分解为更小的问题可以轻松实现分区计数。 @m69 提供的函数让我们完成了 90% 的工作,我们只需要弄清楚如何处理有上限的附加限制。这是我们得到的地方:
int pCount(int n, int m, int myMax) {
if (myMax * m < n) return 0;
if (myMax * m == n) return 1;
我们还必须记住,myMax
会随着我们的前进而减少。如果我们查看上面的 6th 分区,这是有道理的:
2 2 4 6
为了从现在开始计算分区的数量,我们必须继续将翻译应用于 "unit" 的情况。这看起来像:
1 1 3 5
之前的步骤,我们有最大值6
,现在我们只考虑最大值5
。
考虑到这一点,取消分区的排名与取消标准排列或组合的排名没有什么不同。我们必须能够计算给定部分中的分区数。比如统计上面以10
开头的分区数,我们只需要去掉第一列的10
即可:
10 10 15 15
10 11 14 15
10 12 13 15
10 12 14 14
10 13 13 14
10 15 15
11 14 15
12 13 15
12 14 14
13 13 14
翻译成单位大小写:
1 6 6
2 5 6
3 4 6
3 5 5
4 4 5
并调用 pCount
:
pCount(13, 3, 6) = 5
给定一个要取消排序的随机整数,我们继续计算越来越小的分区的数量(就像我们上面所做的那样),直到我们填满我们的索引向量。
示例
给定 min = 3
、max = 10
、n = 7
和 sum = 42
,这是一个生成 20 个随机分区的 ideone 演示。输出如下:
42: 3 3 6 7 7 8 8
123: 4 4 6 6 6 7 9
2: 3 3 3 4 9 10 10
125: 4 4 6 6 7 7 8
104: 4 4 4 6 6 8 10
74: 3 4 6 7 7 7 8
47: 3 4 4 5 6 10 10
146: 5 5 5 5 6 7 9
70: 3 4 6 6 6 7 10
134: 4 5 5 6 6 7 9
136: 4 5 5 6 7 7 8
81: 3 5 5 5 8 8 8
122: 4 4 6 6 6 6 10
112: 4 4 5 5 6 8 10
147: 5 5 5 5 6 8 8
142: 4 6 6 6 6 7 7
37: 3 3 6 6 6 9 9
67: 3 4 5 6 8 8 8
45: 3 4 4 4 8 9 10
44: 3 4 4 4 7 10 10
左边是字典索引,右边是未排序的分区。
如果在[l,x-1]范围内均匀生成0≤a≤1个随机值,在[x,h]范围内均匀生成1-a个随机值,则期望均值将是:
m = ((l+x-1)/2)*a + ((x+h)/2)*(1-a)
因此,如果您想要特定的 m,可以使用 a 和 x。
例如,如果您设置 x = m:a = (h-m)/(h-l+1)。
为确保不同组合的概率更接近均匀,请从上述等式的有效解集中随机选择 a 或 x。 (x 必须在 [l, h] 范围内并且应该是(接近)一个整数;N*a 也应该是(接近)一个整数。
我为 Python-numpy 实现了(未排序的)算法,每个随机数都有单独的范围 [min, max]。也许它对使用 Python 作为主要编程语言的人有用。
import numpy as np
def randint_sum_equal_to(sum_value: int,
n: int,
lower: (int, list) = 0,
upper: (int,list) = None):
# Control on input
if isinstance(lower, (list, np.ndarray)):
assert len(lower) == n
else:
lower = lower * np.ones(n)
if isinstance(upper, (list, np.ndarray)):
assert len(upper) == n
elif upper is None:
upper = sum_value * np.ones(n)
else:
upper = upper * np.ones(n)
# Trivial solutions
if np.sum(upper) < sum_value:
raise ValueError('No solution can be found: sum(upper_bound) < sum_value')
elif np.sum(lower) > sum_value:
raise ValueError('No solution can be found: sum(lower_bound) > sum_value')
elif np.sum(upper) == sum_value:
return upper
elif np.sum(lower) == sum_value:
return lower
# Setup phase
# I generate the table t(y,x) storing the relative probability that the sum of y numbers
# (in the appropriate range) is equal x.
t = np.zeros((n + 1, sum_value))
t[0, 0] = 1
for i in np.arange(1, n + 1):
# Build the k indexes which are taken for each j following k from 0 to min(u(i-1)-l(i-1), j).
# This can be obtained creating a repetition matrix of from t[i] multiplied by the triangular matrix
# tri_mask and then sum each row
tri_mask = np.tri(sum_value, k=0) - np.tri(sum_value, k=-(upper[i-1] - lower[i-1]))
t[i] = np.sum(np.repeat(t[i-1][np.newaxis], sum_value, 0)*tri_mask, axis=1)
# Sampling phase
values = np.zeros(n)
s = (sum_value - np.sum(lower)).astype(int)
for i in np.arange(n)[::-1]:
# The basic algorithm is the one commented:
# v = np.round(np.random.rand() * t[i+1, s])
# r = lower[i]
# v -= t[i, s]
# while (v >= 0) and (s > 0):
# s -= 1
# v -= t[i, s]
# r += 1
# values[i] = r
# ---------------------------------------------------- #
# To speed up the convergence I use some numpy tricks.
# The idea is the same of the Setup phase:
# - I build a repeat matrix of t[i, s:1];
# - I take only the lower triangular part, multiplying by a np.tri(s)
# - I sum over rows, so each element of sum_t contains the cumulative sum of t[i, s - k]
# - I subtract v - sum_t and count the element greater of equal zero,
# which are used to set the output and update s
v = np.round(np.random.rand() * t[i+1, s])
values[i] = lower[i]
sum_t = np.sum(np.repeat(t[i, np.arange(1, s + 1)[::-1]][np.newaxis], s, 0) * np.tri(s), axis=1)
vt_difference_nonzero = np.sum(np.repeat(v, s) - sum_t >= 0)
values[i] += vt_difference_nonzero
s -= vt_difference_nonzero
return values.astype(int)
有没有一种有效的方法来生成 N 个整数的随机组合,使得——
- 每个整数都在区间[
min
,max
], - 整数的总和为
sum
, - 整数可以以任何顺序出现(例如,随机顺序),并且
- 从满足其他要求的所有组合中随机统一选择组合?
是否有类似的随机组合算法,其中整数必须按其值的排序顺序(而不是任何顺序)出现?
(如果sum = N * mean
,选择一个均值为mean
的合适组合是一个特例。这个问题相当于生成一个sum
的均匀随机划分成N个部分每个都在区间 [min
、max
] 中,并以任意顺序出现或按它们的值排序,视情况而定。)
我知道对于以随机顺序出现的组合,可以通过以下方式解决此问题(编辑 [4 月 27 日]:算法已修改。):
如果
N * max < sum
或者N * min > sum
,无解如果
N * max == sum
,只有一个解,所有N
个数都等于max
。如果N * min == sum
,只有一个解,所有N
个数都等于min
。Use the algorithm Smith and Tromble ("Sampling from the Unit Simplex", 2004) 给出的生成 N 个随机非负整数的总和
sum - N * min
.将
min
添加到以这种方式生成的每个数字。如果任何数字大于
max
,则转到步骤3。
但是,如果 max
远小于 sum
,则此算法会很慢。例如,根据我的测试(上面涉及 mean
的特殊情况的实现),该算法平均拒绝 -
- 如果
N = 7, min = 3, max = 10, sum = 42
,大约 1.6 个样本,但是 - 大约 30.6 个样本,如果
N = 20, min = 3, max = 10, sum = 120
。
有没有办法修改此算法以使其对大 N 有效,同时仍满足上述要求?
编辑:
作为评论中建议的替代方案,产生有效随机组合(满足除最后一个要求之外的所有要求)的有效方法是:
- 计算
X
,给定sum
、min
和max
的可能有效组合数。 - 选择
Y
,[0, X)
中的均匀随机整数。 - 将 ("unrank")
Y
转换为有效组合。
但是,有没有计算有效组合(或排列)数的公式,有没有办法将整数转换为有效组合? [编辑(4 月 28 日):排列相同而不是组合]。
编辑(4 月 27 日):
在阅读了 Devroye 的 Non-Uniform Random Variate Generation (1986) 之后,我可以确认这是一个生成随机分区的问题。此外,第 661 页的练习 2(尤其是 E 部分)与此问题相关。
编辑(4 月 28 日):
事实证明,我给出的算法是统一的,其中涉及的整数以 随机顺序 给出,而不是 按其值排序的顺序 。由于这两个问题都是普遍感兴趣的,所以我修改了这个问题以寻求这两个问题的规范答案。
以下Ruby代码可用于验证潜在的一致性解决方案(其中algorithm(...)
是候选算法):
combos={}
permus={}
mn=0
mx=6
sum=12
for x in mn..mx
for y in mn..mx
for z in mn..mx
if x+y+z==sum
permus[[x,y,z]]=0
end
if x+y+z==sum and x<=y and y<=z
combos[[x,y,z]]=0
end
end
end
end
3000.times {|x|
f=algorithm(3,sum,mn,mx)
combos[f.sort]+=1
permus[f]+=1
}
p combos
p permus
编辑(4 月 29 日):重新添加 Ruby 当前实现的代码。
Ruby中给出了以下代码示例,但我的问题与编程语言无关:
def posintwithsum(n, total)
raise if n <= 0 or total <=0
ls = [0]
ret = []
while ls.length < n
c = 1+rand(total-1)
found = false
for j in 1...ls.length
if ls[j] == c
found = true
break
end
end
if found == false;ls.push(c);end
end
ls.sort!
ls.push(total)
for i in 1...ls.length
ret.push(ls[i] - ls[i - 1])
end
return ret
end
def integersWithSum(n, total)
raise if n <= 0 or total <=0
ret = posintwithsum(n, total + n)
for i in 0...ret.length
ret[i] = ret[i] - 1
end
return ret
end
# Generate 100 valid samples
mn=3
mx=10
sum=42
n=7
100.times {
while true
pp=integersWithSum(n,sum-n*mn).map{|x| x+mn }
if !pp.find{|x| x>mx }
p pp; break # Output the sample and break
end
end
}
我还没有测试过这个,所以它不是一个真正的答案,只是尝试一下太长而无法放入评论中。从满足前两个条件的数组开始并使用它,使其仍然满足前两个条件,但更加随机。
如果平均值是一个整数,那么你的初始数组可以是 [4, 4, 4, ... 4] 或者 [3, 4, 5, 3, 4, 5, ... 5, 8, 0] 或类似的简单内容。对于 4.5 的平均值,请尝试 [4, 5, 4, 5, ... 4, 5].
接下来在数组中选择一对数字,num1
和 num2
。可能第一个数字应该按顺序取,就像 Fisher-Yates 洗牌一样,第二个数字应该随机选择。按顺序取第一个号码可确保每个号码至少被选中一次。
现在计算max-num1
和num2-min
。这些是从两个数字到 max
和 min
边界的距离。将 limit
设置为两个距离中较小的一个。这是允许的最大更改,不会使一个或另一个数字超出允许的限制。如果 limit
为零,则跳过这对。
在 [1, limit
] 范围内随机选择一个整数:称其为 change
。我从可选范围中省略了 0,因为它没有任何效果。测试可能会表明通过包含它可以获得更好的随机性;我不确定。
现在设置 num1 <- num1 + change
和 num2 <- num2 - change
。这不会影响平均值,并且数组的所有元素仍在要求的边界内。
您将需要 运行 至少遍历整个数组一次。测试应该显示您是否需要 运行 多次通过它来获得足够随机的东西。
预计到达时间:包括伪代码
// Set up the array.
resultAry <- new array size N
for (i <- 0 to N-1)
// More complex initial setup schemes are possible here.
resultAry[i] <- mean
rof
// Munge the array entries.
for (ix1 <- 0 to N-1) // ix1 steps through the array in order.
// Pick second entry different from first.
repeat
ix2 <- random(0, N-1)
until (ix2 != ix1)
// Calculate size of allowed change.
hiLimit <- max - resultAry[ix1]
loLimit <- resultAry[ix2] - min
limit <- minimum(hiLimit, loLimit)
if (limit == 0)
// No change possible so skip.
continue loop with next ix1
fi
// Change the two entries keeping same mean.
change <- random(1, limit) // Or (0, limit) possibly.
resultAry[ix1] <- resultAry[ix1] + change
resultAry[ix2] <- resultAry[ix2] - change
rof
// Check array has been sufficiently munged.
if (resultAry not random enough)
munge the array again
fi
这是我在 Java 中的解决方案。它功能齐全,包含两个生成器:PermutationPartitionGenerator
用于未排序的分区,CombinationPartitionGenerator
用于排序的分区。您的生成器也在 class SmithTromblePartitionGenerator
中实现以供比较。 class SequentialEnumerator
按顺序枚举所有可能的分区(未排序或排序,取决于参数)。我已经为所有这些生成器添加了全面的测试(包括您的测试用例)。
在大多数情况下,实现是不言自明的。有什么问题,过几天我会回复的。
import java.util.Random;
import java.util.function.Supplier;
public abstract class PartitionGenerator implements Supplier<int[]>{
public static final Random rand = new Random();
protected final int numberCount;
protected final int min;
protected final int range;
protected final int sum; // shifted sum
protected final boolean sorted;
protected PartitionGenerator(int numberCount, int min, int max, int sum, boolean sorted) {
if (numberCount <= 0)
throw new IllegalArgumentException("Number count should be positive");
this.numberCount = numberCount;
this.min = min;
range = max - min;
if (range < 0)
throw new IllegalArgumentException("min > max");
sum -= numberCount * min;
if (sum < 0)
throw new IllegalArgumentException("Sum is too small");
if (numberCount * range < sum)
throw new IllegalArgumentException("Sum is too large");
this.sum = sum;
this.sorted = sorted;
}
// Whether this generator returns sorted arrays (i.e. combinations)
public final boolean isSorted() {
return sorted;
}
public interface GeneratorFactory {
PartitionGenerator create(int numberCount, int min, int max, int sum);
}
}
import java.math.BigInteger;
// Permutations with repetition (i.e. unsorted vectors) with given sum
public class PermutationPartitionGenerator extends PartitionGenerator {
private final double[][] distributionTable;
public PermutationPartitionGenerator(int numberCount, int min, int max, int sum) {
super(numberCount, min, max, sum, false);
distributionTable = calculateSolutionCountTable();
}
private double[][] calculateSolutionCountTable() {
double[][] table = new double[numberCount + 1][sum + 1];
BigInteger[] a = new BigInteger[sum + 1];
BigInteger[] b = new BigInteger[sum + 1];
for (int i = 1; i <= sum; i++)
a[i] = BigInteger.ZERO;
a[0] = BigInteger.ONE;
table[0][0] = 1.0;
for (int n = 1; n <= numberCount; n++) {
double[] t = table[n];
for (int s = 0; s <= sum; s++) {
BigInteger z = BigInteger.ZERO;
for (int i = Math.max(0, s - range); i <= s; i++)
z = z.add(a[i]);
b[s] = z;
t[s] = z.doubleValue();
}
// swap a and b
BigInteger[] c = b;
b = a;
a = c;
}
return table;
}
@Override
public int[] get() {
int[] p = new int[numberCount];
int s = sum; // current sum
for (int i = numberCount - 1; i >= 0; i--) {
double t = rand.nextDouble() * distributionTable[i + 1][s];
double[] tableRow = distributionTable[i];
int oldSum = s;
// lowerBound is introduced only for safety, it shouldn't be crossed
int lowerBound = s - range;
if (lowerBound < 0)
lowerBound = 0;
s++;
do
t -= tableRow[--s];
// s can be equal to lowerBound here with t > 0 only due to imprecise subtraction
while (t > 0 && s > lowerBound);
p[i] = min + (oldSum - s);
}
assert s == 0;
return p;
}
public static final GeneratorFactory factory = (numberCount, min, max,sum) ->
new PermutationPartitionGenerator(numberCount, min, max, sum);
}
import java.math.BigInteger;
// Combinations with repetition (i.e. sorted vectors) with given sum
public class CombinationPartitionGenerator extends PartitionGenerator {
private final double[][][] distributionTable;
public CombinationPartitionGenerator(int numberCount, int min, int max, int sum) {
super(numberCount, min, max, sum, true);
distributionTable = calculateSolutionCountTable();
}
private double[][][] calculateSolutionCountTable() {
double[][][] table = new double[numberCount + 1][range + 1][sum + 1];
BigInteger[][] a = new BigInteger[range + 1][sum + 1];
BigInteger[][] b = new BigInteger[range + 1][sum + 1];
double[][] t = table[0];
for (int m = 0; m <= range; m++) {
a[m][0] = BigInteger.ONE;
t[m][0] = 1.0;
for (int s = 1; s <= sum; s++) {
a[m][s] = BigInteger.ZERO;
t[m][s] = 0.0;
}
}
for (int n = 1; n <= numberCount; n++) {
t = table[n];
for (int m = 0; m <= range; m++)
for (int s = 0; s <= sum; s++) {
BigInteger z;
if (m == 0)
z = a[0][s];
else {
z = b[m - 1][s];
if (m <= s)
z = z.add(a[m][s - m]);
}
b[m][s] = z;
t[m][s] = z.doubleValue();
}
// swap a and b
BigInteger[][] c = b;
b = a;
a = c;
}
return table;
}
@Override
public int[] get() {
int[] p = new int[numberCount];
int m = range; // current max
int s = sum; // current sum
for (int i = numberCount - 1; i >= 0; i--) {
double t = rand.nextDouble() * distributionTable[i + 1][m][s];
double[][] tableCut = distributionTable[i];
if (s < m)
m = s;
s -= m;
while (true) {
t -= tableCut[m][s];
// m can be 0 here with t > 0 only due to imprecise subtraction
if (t <= 0 || m == 0)
break;
m--;
s++;
}
p[i] = min + m;
}
assert s == 0;
return p;
}
public static final GeneratorFactory factory = (numberCount, min, max, sum) ->
new CombinationPartitionGenerator(numberCount, min, max, sum);
}
import java.util.*;
public class SmithTromblePartitionGenerator extends PartitionGenerator {
public SmithTromblePartitionGenerator(int numberCount, int min, int max, int sum) {
super(numberCount, min, max, sum, false);
}
@Override
public int[] get() {
List<Integer> ls = new ArrayList<>(numberCount + 1);
int[] ret = new int[numberCount];
int increasedSum = sum + numberCount;
while (true) {
ls.add(0);
while (ls.size() < numberCount) {
int c = 1 + rand.nextInt(increasedSum - 1);
if (!ls.contains(c))
ls.add(c);
}
Collections.sort(ls);
ls.add(increasedSum);
boolean good = true;
for (int i = 0; i < numberCount; i++) {
int x = ls.get(i + 1) - ls.get(i) - 1;
if (x > range) {
good = false;
break;
}
ret[i] = x;
}
if (good) {
for (int i = 0; i < numberCount; i++)
ret[i] += min;
return ret;
}
ls.clear();
}
}
public static final GeneratorFactory factory = (numberCount, min, max, sum) ->
new SmithTromblePartitionGenerator(numberCount, min, max, sum);
}
import java.util.Arrays;
// Enumerates all partitions with given parameters
public class SequentialEnumerator extends PartitionGenerator {
private final int max;
private final int[] p;
private boolean finished;
public SequentialEnumerator(int numberCount, int min, int max, int sum, boolean sorted) {
super(numberCount, min, max, sum, sorted);
this.max = max;
p = new int[numberCount];
startOver();
}
private void startOver() {
finished = false;
int unshiftedSum = sum + numberCount * min;
fillMinimal(0, Math.max(min, unshiftedSum - (numberCount - 1) * max), unshiftedSum);
}
private void fillMinimal(int beginIndex, int minValue, int fillSum) {
int fillRange = max - minValue;
if (fillRange == 0)
Arrays.fill(p, beginIndex, numberCount, max);
else {
int fillCount = numberCount - beginIndex;
fillSum -= fillCount * minValue;
int maxCount = fillSum / fillRange;
int maxStartIndex = numberCount - maxCount;
Arrays.fill(p, maxStartIndex, numberCount, max);
fillSum -= maxCount * fillRange;
Arrays.fill(p, beginIndex, maxStartIndex, minValue);
if (fillSum != 0)
p[maxStartIndex - 1] = minValue + fillSum;
}
}
@Override
public int[] get() { // returns null when there is no more partition, then starts over
if (finished) {
startOver();
return null;
}
int[] pCopy = p.clone();
if (numberCount > 1) {
int i = numberCount;
int s = p[--i];
while (i > 0) {
int x = p[--i];
if (x == max) {
s += x;
continue;
}
x++;
s--;
int minRest = sorted ? x : min;
if (s < minRest * (numberCount - i - 1)) {
s += x;
continue;
}
p[i++]++;
fillMinimal(i, minRest, s);
return pCopy;
}
}
finished = true;
return pCopy;
}
public static final GeneratorFactory permutationFactory = (numberCount, min, max, sum) ->
new SequentialEnumerator(numberCount, min, max, sum, false);
public static final GeneratorFactory combinationFactory = (numberCount, min, max, sum) ->
new SequentialEnumerator(numberCount, min, max, sum, true);
}
import java.util.*;
import java.util.function.BiConsumer;
import PartitionGenerator.GeneratorFactory;
public class Test {
private final int numberCount;
private final int min;
private final int max;
private final int sum;
private final int repeatCount;
private final BiConsumer<PartitionGenerator, Test> procedure;
public Test(int numberCount, int min, int max, int sum, int repeatCount,
BiConsumer<PartitionGenerator, Test> procedure) {
this.numberCount = numberCount;
this.min = min;
this.max = max;
this.sum = sum;
this.repeatCount = repeatCount;
this.procedure = procedure;
}
@Override
public String toString() {
return String.format("=== %d numbers from [%d, %d] with sum %d, %d iterations ===",
numberCount, min, max, sum, repeatCount);
}
private static class GeneratedVector {
final int[] v;
GeneratedVector(int[] vect) {
v = vect;
}
@Override
public int hashCode() {
return Arrays.hashCode(v);
}
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
return Arrays.equals(v, ((GeneratedVector)obj).v);
}
@Override
public String toString() {
return Arrays.toString(v);
}
}
private static final Comparator<Map.Entry<GeneratedVector, Integer>> lexicographical = (e1, e2) -> {
int[] v1 = e1.getKey().v;
int[] v2 = e2.getKey().v;
int len = v1.length;
int d = len - v2.length;
if (d != 0)
return d;
for (int i = 0; i < len; i++) {
d = v1[i] - v2[i];
if (d != 0)
return d;
}
return 0;
};
private static final Comparator<Map.Entry<GeneratedVector, Integer>> byCount =
Comparator.<Map.Entry<GeneratedVector, Integer>>comparingInt(Map.Entry::getValue)
.thenComparing(lexicographical);
public static int SHOW_MISSING_LIMIT = 10;
private static void checkMissingPartitions(Map<GeneratedVector, Integer> map, PartitionGenerator reference) {
int missingCount = 0;
while (true) {
int[] v = reference.get();
if (v == null)
break;
GeneratedVector gv = new GeneratedVector(v);
if (!map.containsKey(gv)) {
if (missingCount == 0)
System.out.println(" Missing:");
if (++missingCount > SHOW_MISSING_LIMIT) {
System.out.println(" . . .");
break;
}
System.out.println(gv);
}
}
}
public static final BiConsumer<PartitionGenerator, Test> distributionTest(boolean sortByCount) {
return (PartitionGenerator gen, Test test) -> {
System.out.print("\n" + getName(gen) + "\n\n");
Map<GeneratedVector, Integer> combos = new HashMap<>();
// There's no point of checking permus for sorted generators
// because they are the same as combos for them
Map<GeneratedVector, Integer> permus = gen.isSorted() ? null : new HashMap<>();
for (int i = 0; i < test.repeatCount; i++) {
int[] v = gen.get();
if (v == null && gen instanceof SequentialEnumerator)
break;
if (permus != null) {
permus.merge(new GeneratedVector(v), 1, Integer::sum);
v = v.clone();
Arrays.sort(v);
}
combos.merge(new GeneratedVector(v), 1, Integer::sum);
}
Set<Map.Entry<GeneratedVector, Integer>> sortedEntries = new TreeSet<>(
sortByCount ? byCount : lexicographical);
System.out.println("Combos" + (gen.isSorted() ? ":" : " (don't have to be uniform):"));
sortedEntries.addAll(combos.entrySet());
for (Map.Entry<GeneratedVector, Integer> e : sortedEntries)
System.out.println(e);
checkMissingPartitions(combos, test.getGenerator(SequentialEnumerator.combinationFactory));
if (permus != null) {
System.out.println("\nPermus:");
sortedEntries.clear();
sortedEntries.addAll(permus.entrySet());
for (Map.Entry<GeneratedVector, Integer> e : sortedEntries)
System.out.println(e);
checkMissingPartitions(permus, test.getGenerator(SequentialEnumerator.permutationFactory));
}
};
}
public static final BiConsumer<PartitionGenerator, Test> correctnessTest =
(PartitionGenerator gen, Test test) -> {
String genName = getName(gen);
for (int i = 0; i < test.repeatCount; i++) {
int[] v = gen.get();
if (v == null && gen instanceof SequentialEnumerator)
v = gen.get();
if (v.length != test.numberCount)
throw new RuntimeException(genName + ": array of wrong length");
int s = 0;
if (gen.isSorted()) {
if (v[0] < test.min || v[v.length - 1] > test.max)
throw new RuntimeException(genName + ": generated number is out of range");
int prev = test.min;
for (int x : v) {
if (x < prev)
throw new RuntimeException(genName + ": unsorted array");
s += x;
prev = x;
}
} else
for (int x : v) {
if (x < test.min || x > test.max)
throw new RuntimeException(genName + ": generated number is out of range");
s += x;
}
if (s != test.sum)
throw new RuntimeException(genName + ": wrong sum");
}
System.out.format("%30s : correctness test passed%n", genName);
};
public static final BiConsumer<PartitionGenerator, Test> performanceTest =
(PartitionGenerator gen, Test test) -> {
long time = System.nanoTime();
for (int i = 0; i < test.repeatCount; i++)
gen.get();
time = System.nanoTime() - time;
System.out.format("%30s : %8.3f s %10.0f ns/test%n", getName(gen), time * 1e-9, time * 1.0 / test.repeatCount);
};
public PartitionGenerator getGenerator(GeneratorFactory factory) {
return factory.create(numberCount, min, max, sum);
}
public static String getName(PartitionGenerator gen) {
String name = gen.getClass().getSimpleName();
if (gen instanceof SequentialEnumerator)
return (gen.isSorted() ? "Sorted " : "Unsorted ") + name;
else
return name;
}
public static GeneratorFactory[] factories = { SmithTromblePartitionGenerator.factory,
PermutationPartitionGenerator.factory, CombinationPartitionGenerator.factory,
SequentialEnumerator.permutationFactory, SequentialEnumerator.combinationFactory };
public static void main(String[] args) {
Test[] tests = {
new Test(3, 0, 3, 5, 3_000, distributionTest(false)),
new Test(3, 0, 6, 12, 3_000, distributionTest(true)),
new Test(50, -10, 20, 70, 2_000, correctnessTest),
new Test(7, 3, 10, 42, 1_000_000, performanceTest),
new Test(20, 3, 10, 120, 100_000, performanceTest)
};
for (Test t : tests) {
System.out.println(t);
for (GeneratorFactory factory : factories) {
PartitionGenerator candidate = t.getGenerator(factory);
t.procedure.accept(candidate, t);
}
System.out.println();
}
}
}
这是来自 John McClane 的 PermutationPartitionGenerator 的算法,在本页的另一个答案中。它有两个阶段,即设置阶段和采样阶段,并在 [min
、max
] 中生成 n
个随机变量,总和为 sum
,其中数字为以随机顺序列出。
设置阶段:首先,使用以下公式构建解决方案 table(t(y, x)
其中 y
在 [0,n
] 和 x
在 [0, sum - n * min
]):
- t(0, j) = 1 如果 j == 0; 0 否则
- t(i, j) = t(i-1, j) + t(i-1, j-1) + ... + t(i-1, j-(max-min))
这里,t(y, x) 存储 y
个数字(在适当范围内)的总和等于 x
的相对概率。这个概率是相对于所有具有相同y
.
采样阶段:这里我们生成 n
个数字的样本。将 s
设置为 sum - n * min
,然后对于每个位置 i
,从 n - 1
开始并向后计算到 0:
- 设置
v
为[0,t(i+1,s)]中的一个均匀随机整数。 - 将
r
设置为min
。 - 从
v
中减去 t(i, s)。 - 当
v
保持为 0 或更大时,从v
中减去 t(i, s-1),将r
加 1,然后从s
中减去 1 . - 样本中位置
i
的数字设置为r
。
编辑:
看来,通过对上述算法进行微不足道的更改,可以让每个随机变量使用单独的范围,而不是对所有变量使用相同的范围:
位置i
∈[0,n
)的每个随机变量都有一个最小值min(i)和一个最大值max(i)。
令adjsum
= sum
- ∑min(i).
设置阶段:首先,使用以下公式构建解决方案 table(t(y, x)
其中 y
在 [0,n
] 和 x
在 [0, adjsum
]):
- t(0, j) = 1 如果 j == 0; 0 否则
- t(i, j) = t(i-1, j) + t(i-1, j-1) + ... + t(i-1, j-(最大(i-1)-最小(i-1)))
采样阶段与之前完全相同,只是我们将 s
设置为 adjsum
(而不是 sum - n * min
)并将 r
设置为 min(i )(而不是 min
)。
编辑:
对于 John McClane 的 CombinationPartitionGenerator,设置和采样阶段如下。
设置阶段:首先,使用以下公式构建解决方案 table(t(z, y, x)
其中 z
在 [0,n
],y
在[0,max - min
],x
在[0,sum - n * min
]):
- t(0, j, k) = 1 如果 k == 0; 0 否则
- t(i, 0, k) = t(i - 1, 0, k)
- t(i, j, k) = t(i, j-1, k) + t(i - 1, j, k - j)
采样阶段:这里我们生成 n
个数字的样本。将 s
设置为 sum - n * min
并将 mrange
设置为 max - min
,然后对于每个位置 i
,从 n - 1
开始并向后计算到 0:
- 设置
v
为[0, t(i+1, mrange, s)]中的一个均匀随机整数。 - 将
mrange
设置为最小值(mrange
,s
) - 从
s
中减去mrange
。 - 将
r
设置为min + mrange
。 - 从
v
. 中减去 t( - 当
v
保持为 0 或更大时,将s
加 1,从r
中减去 1,从mrange
中减去 1,然后减去 t(i
,mrange
,s
) 来自v
. - 样本中位置
i
的数字设置为r
。
i
, mrange
, s
)
正如 OP 指出的那样,有效取消排名的能力非常强大。如果我们能够这样做,则可以通过三个步骤生成分区的均匀分布(重申 OP 在问题中提出的内容):
- 计算长度为N[=106的分区总数M =] 的数字
sum
使得部分在 [min
,max
]. 范围内
- 从
[1, M]
. 生成均匀分布的整数
- 将步骤 2 中的每个整数取消排序到其各自的分区中。
下面,我们只关注生成第nth分区,因为生成均匀分布的信息量很大给定范围内的整数。这是一个简单的 C++
unranking 算法,应该很容易翻译成其他语言(N.B。我还没有想出如何取消组合案例(即顺序很重要))。
std::vector<int> unRank(int n, int m, int myMax, int nth) {
std::vector<int> z(m, 0);
int count = 0;
int j = 0;
for (int i = 0; i < z.size(); ++i) {
int temp = pCount(n - 1, m - 1, myMax);
for (int r = n - m, k = myMax - 1;
(count + temp) < nth && r > 0 && k; r -= m, --k) {
count += temp;
n = r;
myMax = k;
++j;
temp = pCount(n - 1, m - 1, myMax);
}
--m;
--n;
z[i] = j;
}
return z;
}
主力 pCount
函数由:
int pCount(int n, int m, int myMax) {
if (myMax * m < n) return 0;
if (myMax * m == n) return 1;
if (m < 2) return m;
if (n < m) return 0;
if (n <= m + 1) return 1;
int niter = n / m;
int count = 0;
for (; niter--; n -= m, --myMax) {
count += pCount(n - 1, m - 1, myMax);
}
return count;
}
此功能基于用户@m69_snarky_and_unwelcoming 对
unRank
的解释
我们首先注意到长度为 N 的分区有一个一对一的映射 sum
使得部分在 [min
, max
] 到长度为 N 的受限分区范围内sum - N * (min - 1)
的数字 [1
、max - (min - 1)
].
作为一个小例子,考虑 50
长度 4
的分区,使得 min = 10
和 max = 15
。这将与长度为 4
的 50 - 4 * (10 - 1) = 14
的限制分区具有相同的结构,最大部分等于 15 - (10 - 1) = 6
.
10 10 15 15 --->> 1 1 6 6
10 11 14 15 --->> 1 2 5 6
10 12 13 15 --->> 1 3 4 6
10 12 14 14 --->> 1 3 5 5
10 13 13 14 --->> 1 4 4 5
11 11 13 15 --->> 2 2 4 6
11 11 14 14 --->> 2 2 5 5
11 12 12 15 --->> 2 3 3 6
11 12 13 14 --->> 2 3 4 5
11 13 13 13 --->> 2 4 4 4
12 12 12 14 --->> 3 3 3 5
12 12 13 13 --->> 3 3 4 4
考虑到这一点,为了方便计数,如果您愿意,我们可以添加步骤 1a 将问题转换为 "unit" 情况。
现在,我们只是遇到了一个计数问题。正如 @m69 出色地展示的那样,通过将问题分解为更小的问题可以轻松实现分区计数。 @m69 提供的函数让我们完成了 90% 的工作,我们只需要弄清楚如何处理有上限的附加限制。这是我们得到的地方:
int pCount(int n, int m, int myMax) {
if (myMax * m < n) return 0;
if (myMax * m == n) return 1;
我们还必须记住,myMax
会随着我们的前进而减少。如果我们查看上面的 6th 分区,这是有道理的:
2 2 4 6
为了从现在开始计算分区的数量,我们必须继续将翻译应用于 "unit" 的情况。这看起来像:
1 1 3 5
之前的步骤,我们有最大值6
,现在我们只考虑最大值5
。
考虑到这一点,取消分区的排名与取消标准排列或组合的排名没有什么不同。我们必须能够计算给定部分中的分区数。比如统计上面以10
开头的分区数,我们只需要去掉第一列的10
即可:
10 10 15 15
10 11 14 15
10 12 13 15
10 12 14 14
10 13 13 14
10 15 15
11 14 15
12 13 15
12 14 14
13 13 14
翻译成单位大小写:
1 6 6
2 5 6
3 4 6
3 5 5
4 4 5
并调用 pCount
:
pCount(13, 3, 6) = 5
给定一个要取消排序的随机整数,我们继续计算越来越小的分区的数量(就像我们上面所做的那样),直到我们填满我们的索引向量。
示例
给定 min = 3
、max = 10
、n = 7
和 sum = 42
,这是一个生成 20 个随机分区的 ideone 演示。输出如下:
42: 3 3 6 7 7 8 8
123: 4 4 6 6 6 7 9
2: 3 3 3 4 9 10 10
125: 4 4 6 6 7 7 8
104: 4 4 4 6 6 8 10
74: 3 4 6 7 7 7 8
47: 3 4 4 5 6 10 10
146: 5 5 5 5 6 7 9
70: 3 4 6 6 6 7 10
134: 4 5 5 6 6 7 9
136: 4 5 5 6 7 7 8
81: 3 5 5 5 8 8 8
122: 4 4 6 6 6 6 10
112: 4 4 5 5 6 8 10
147: 5 5 5 5 6 8 8
142: 4 6 6 6 6 7 7
37: 3 3 6 6 6 9 9
67: 3 4 5 6 8 8 8
45: 3 4 4 4 8 9 10
44: 3 4 4 4 7 10 10
左边是字典索引,右边是未排序的分区。
如果在[l,x-1]范围内均匀生成0≤a≤1个随机值,在[x,h]范围内均匀生成1-a个随机值,则期望均值将是:
m = ((l+x-1)/2)*a + ((x+h)/2)*(1-a)
因此,如果您想要特定的 m,可以使用 a 和 x。
例如,如果您设置 x = m:a = (h-m)/(h-l+1)。
为确保不同组合的概率更接近均匀,请从上述等式的有效解集中随机选择 a 或 x。 (x 必须在 [l, h] 范围内并且应该是(接近)一个整数;N*a 也应该是(接近)一个整数。
我为 Python-numpy 实现了(未排序的)算法,每个随机数都有单独的范围 [min, max]。也许它对使用 Python 作为主要编程语言的人有用。
import numpy as np
def randint_sum_equal_to(sum_value: int,
n: int,
lower: (int, list) = 0,
upper: (int,list) = None):
# Control on input
if isinstance(lower, (list, np.ndarray)):
assert len(lower) == n
else:
lower = lower * np.ones(n)
if isinstance(upper, (list, np.ndarray)):
assert len(upper) == n
elif upper is None:
upper = sum_value * np.ones(n)
else:
upper = upper * np.ones(n)
# Trivial solutions
if np.sum(upper) < sum_value:
raise ValueError('No solution can be found: sum(upper_bound) < sum_value')
elif np.sum(lower) > sum_value:
raise ValueError('No solution can be found: sum(lower_bound) > sum_value')
elif np.sum(upper) == sum_value:
return upper
elif np.sum(lower) == sum_value:
return lower
# Setup phase
# I generate the table t(y,x) storing the relative probability that the sum of y numbers
# (in the appropriate range) is equal x.
t = np.zeros((n + 1, sum_value))
t[0, 0] = 1
for i in np.arange(1, n + 1):
# Build the k indexes which are taken for each j following k from 0 to min(u(i-1)-l(i-1), j).
# This can be obtained creating a repetition matrix of from t[i] multiplied by the triangular matrix
# tri_mask and then sum each row
tri_mask = np.tri(sum_value, k=0) - np.tri(sum_value, k=-(upper[i-1] - lower[i-1]))
t[i] = np.sum(np.repeat(t[i-1][np.newaxis], sum_value, 0)*tri_mask, axis=1)
# Sampling phase
values = np.zeros(n)
s = (sum_value - np.sum(lower)).astype(int)
for i in np.arange(n)[::-1]:
# The basic algorithm is the one commented:
# v = np.round(np.random.rand() * t[i+1, s])
# r = lower[i]
# v -= t[i, s]
# while (v >= 0) and (s > 0):
# s -= 1
# v -= t[i, s]
# r += 1
# values[i] = r
# ---------------------------------------------------- #
# To speed up the convergence I use some numpy tricks.
# The idea is the same of the Setup phase:
# - I build a repeat matrix of t[i, s:1];
# - I take only the lower triangular part, multiplying by a np.tri(s)
# - I sum over rows, so each element of sum_t contains the cumulative sum of t[i, s - k]
# - I subtract v - sum_t and count the element greater of equal zero,
# which are used to set the output and update s
v = np.round(np.random.rand() * t[i+1, s])
values[i] = lower[i]
sum_t = np.sum(np.repeat(t[i, np.arange(1, s + 1)[::-1]][np.newaxis], s, 0) * np.tri(s), axis=1)
vt_difference_nonzero = np.sum(np.repeat(v, s) - sum_t >= 0)
values[i] += vt_difference_nonzero
s -= vt_difference_nonzero
return values.astype(int)