Random sample of character vector, without elements prefixing one another

Consider a character vector, pool, whose elements are (zero-padded) binary numbers with up to max_len digits.

max_len <- 4
pool <- unlist(lapply(seq_len(max_len), function(x) 
  do.call(paste0, expand.grid(rep(list(c('0', '1')), x)))))

pool
##  [1] "0"    "1"    "00"   "10"   "01"   "11"   "000"  "100"  "010"  "110" 
## [11] "001"  "101"  "011"  "111"  "0000" "1000" "0100" "1100" "0010" "1010"
## [21] "0110" "1110" "0001" "1001" "0101" "1101" "0011" "1011" "0111" "1111"

I want to sample n of these elements, subject to the constraint that none of the sampled elements is a prefix of any other sampled element (i.e. if we sample 1101, we disallow 1, 11, and 110, while if we sample 1, we disallow those elements beginning with 1, such as 10, 11, 100, etc.).
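For concreteness, the "no element prefixes another" condition can be written as a tiny predicate (a Python sketch of the rule, not part of the R code; the function name is mine):

```python
def is_prefix_free(sample):
    """True if no sampled string is a prefix of another sampled string."""
    return not any(a != b and b.startswith(a)
                   for a in sample for b in sample)
```

So a sample like ("1101", "10") is fine, while ("1", "10") is not, since "1" prefixes "10".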

Below is my attempt using a while loop, which of course becomes slow when n is large (or approaches 2^max_len).

set.seed(1)
n <- 10
chosen <- sample(pool, n)
while(any(rowSums(outer(paste0('^', chosen), chosen, Vectorize(grepl))) > 1)) {
  prefixes <- rowSums(outer(paste0('^', chosen), chosen, Vectorize(grepl))) > 1
  pool <- pool[rowSums(Vectorize(grepl, 'pattern')(
    paste0('^', chosen[!prefixes]), pool)) == 0]
  chosen <- c(chosen[!prefixes], sample(pool, sum(prefixes)))
}

chosen
## [1] "0100" "0101" "0001" "0011" "1000" "111"  "0000" "0110" "1100" "0111"

This could be improved slightly by initially removing from pool those elements whose inclusion would mean that insufficient elements remain in pool to achieve a total sample of size n. For example, when max_len = 4 and n > 9, we can immediately remove 0 and 1 from pool, since by including either, the maximal sample would be 9 (either 0 and the eight 4-character elements beginning with 1, or 1 and the eight 4-character elements beginning with 0).
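The cap described above follows from counting leaves of the depth-max_len binary tree: a sampled element of length len blocks the 2^(max_len - len) leaves beneath it while contributing one pick itself. A Python sketch of that arithmetic (the function name is mine; it mirrors the threshold used in the nchar(pool) filter):

```python
def max_sample_including(length, max_len):
    """Largest prefix-free sample still achievable if one element of the
    given length is included: all 2^max_len leaves, minus the
    2^(max_len - length) leaves of the subtree it blocks, plus itself."""
    return 2**max_len - 2**(max_len - length) + 1
```

With max_len = 4, including a 1-character element caps the sample at 9, matching the example above.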

Based on this logic, we could omit elements from pool, before taking the initial sample, with something like:

pool <- pool[
  nchar(pool) > tail(which(n > (2^max_len - rev(2^(0:max_len))[-1] + 1)), 1)]

Can anyone think of a better approach? I feel like I'm overlooking something much simpler.


EDIT

To clarify my intentions, I'll describe the pool as a set of branches, where junctions and tips are the nodes (elements of pool). Assume the yellow node in the figure below (i.e. 010) was drawn. Now the entire red "branch", consisting of nodes 0, 01, and 010, is removed from the pool. This is what I meant by forbidding sampling of nodes that "prefix" nodes already in our sample (as well as those nodes already prefixed by nodes in our sample).

If the sampled node is half-way along a branch, such as 01 in the figure below, then all the red nodes (0, 01, 010, and 011) are disallowed, since 0 prefixes 01, and 01 prefixes both 010 and 011.

I'm not looking to sample either 1 or 0 at each junction (i.e. walking along the branches, flipping coins at forks) - it's fine to have both in the sample, as long as: (1) no parent (or grand-parent, etc.) or child (grandchild, etc.) nodes of that node have already been sampled; and (2) upon sampling the node, sufficient nodes will remain to achieve the desired sample of size n.

In the second figure above, if 01 is the first pick, all nodes at the black points are still (currently) valid, assuming n <= 4. If, for example, n==4 and we sample node 1 next (so our picks now include 01 and 1), we would subsequently disallow node 00 (due to rule 2 above) but could still pick 000 and 001, giving us our 4-element sample. If, on the other hand, n==5, node 1 would be disallowed at this stage.

One approach would be to simply generate all possible tuples of the appropriate size, with an iterative approach:

  1. Build all tuples of size 1 (all elements in pool)
  2. Take the cross product with the elements in pool
  3. Remove any tuple using the same element of pool more than once
  4. Remove any exact duplicate of another tuple
  5. Remove any tuple with a pair that can't be used together
  6. Rinse and repeat until you get the appropriate tuple size
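For tiny pools, the same enumeration can be sketched by brute force (Python, using itertools instead of the iterative cross-product; all names are mine):

```python
from itertools import combinations, product

def make_pool(max_len):
    """All binary strings of 1..max_len digits."""
    return [''.join(bits) for k in range(1, max_len + 1)
            for bits in product('01', repeat=k)]

def prefix_free(tup):
    """True if no element of the tuple prefixes another."""
    return all(not a.startswith(b) and not b.startswith(a)
               for a, b in combinations(tup, 2))

def all_tuples(pool, n):
    """Every prefix-free n-element subset of pool."""
    return [t for t in combinations(pool, n) if prefix_free(t)]
```

With max_len = 2, the only valid 4-tuple is the set of four leaves 00, 01, 10, and 11.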

This is runnable for the size given (pool of length 30, max_len of 4):

get.template <- function(pool, max_len) {
  banned <- which(outer(paste0('^', pool), pool, Vectorize(grepl)), arr.ind=T)
  banned <- banned[banned[,1] != banned[,2],]
  banned <- paste(banned[,1], banned[,2])
  vals <- matrix(seq(length(pool)))
  for (k in 2:max_len) {
    vals <- cbind(vals[rep(1:nrow(vals), each=length(pool)),],
                  rep(1:length(pool), nrow(vals)))
    # Can't sample same value more than once
    vals <- vals[apply(vals, 1, function(x) length(unique(x)) == length(x)),]
    # Sort rows to ensure unique only
    vals <- t(apply(vals, 1, sort))
    vals <- unique(vals)
    # Can't have banned pair
    combos <- combn(ncol(vals), 2)
    for (j in seq(ncol(combos))) {
        c1 <- combos[1,j]
        c2 <- combos[2,j]
        vals <- vals[!paste(vals[,c1], vals[,c2]) %in% banned,]
    }
  }
  return(matrix(pool[vals], nrow=nrow(vals)))
}

max_len <- 4
pool <- unlist(lapply(seq_len(max_len), function(x) do.call(paste0, expand.grid(rep(list(c('0', '1')), x)))))
system.time(template <- get.template(pool, 4))
#   user  system elapsed 
#  4.549   0.050   4.614 

Now you can sample from the rows of template as many times as desired (which will be very fast), and this will be the same as randomly sampling from the defined space.

If you don't want to generate the set of all possible tuples and then randomly sample from it (which, as you note, may be infeasible for large input sizes), another option would be to draw a single sample with integer programming. Basically you could assign each element in pool a random value and select the feasible tuple with the largest sum of values. This should give every tuple an equal probability of being selected, because they are all the same size and their values were randomly selected. The constraints of the model would ensure that none of the disallowed pairs of tuples is selected and that the correct number of elements is selected.
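The random-objective trick can be sketched without a solver by brute-forcing the argmax over feasible tuples for a tiny pool (a Python sketch with names of my own; with lpSolve the solver performs this search implicitly):

```python
import random
from itertools import combinations

def lp_style_sample(pool, n, rng=random):
    """Among all prefix-free n-tuples, return the one whose random
    weights sum highest; each feasible tuple wins with equal probability."""
    weight = {e: rng.random() for e in pool}
    feasible = [t for t in combinations(pool, n)
                if all(not a.startswith(b) and not b.startswith(a)
                       for a, b in combinations(t, 2))]
    return max(feasible, key=lambda t: sum(weight[e] for e in t))
```

This is only practical for toy pools; the point of the integer-programming formulation is that the solver finds the maximal feasible tuple without enumerating them.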

Here's a solution with the lpSolve package:

library(lpSolve)
sample.lp <- function(pool, max_len) {
  pool <- sort(pool)
  pml <- max(nchar(pool))
  runs <- c(rev(cumsum(2^(seq(pml-1)))), 0)
  banned.from <- rep(seq(pool), runs[nchar(pool)])
  banned.to <- banned.from + unlist(lapply(runs[nchar(pool)], seq_len))
  banned.constr <- matrix(0, nrow=length(banned.from), ncol=length(pool))
  banned.constr[cbind(seq(banned.from), banned.from)] <- 1
  banned.constr[cbind(seq(banned.to), banned.to)] <- 1
  mod <- lp(direction="max",
            objective.in=runif(length(pool)),
            const.mat=rbind(banned.constr, rep(1, length(pool))),
            const.dir=c(rep("<=", length(banned.from)), "=="),
            const.rhs=c(rep(1, length(banned.from)), max_len),
            all.bin=TRUE)
  pool[which(mod$solution == 1)]
}
set.seed(144)
pool <- unlist(lapply(seq_len(4), function(x) do.call(paste0, expand.grid(rep(list(c('0', '1')), x)))))
sample.lp(pool, 4)
# [1] "0011" "010"  "1000" "1100"
sample.lp(pool, 8)
# [1] "0000" "0100" "0110" "1001" "1010" "1100" "1101" "1110"

This seems to scale to quite large pools. For instance, it takes about a quarter of a second to get a sample of length 20 from a pool of size 510:

pool <- unlist(lapply(seq_len(8), function(x) do.call(paste0, expand.grid(rep(list(c('0', '1')), x)))))
length(pool)
# [1] 510
system.time(sample.lp(pool, 20))
#    user  system elapsed 
#   0.232   0.008   0.239 

If you really needed to solve very large problem sizes, you could move from the open-source solver shipped with lpSolve to a commercial solver like gurobi or cplex (not free in general, but free for academic use).

You can sort the pool to help decide which elements to disqualify. For example, looking at a sorted pool of elements up to three characters long:

 [1] "0"   "00"  "000" "001" "01"  "010" "011" "1"   "10"  "100" "101" "11" 
[13] "110" "111"

I can tell that I can disqualify anything that follows my selected item and has more characters than my item, until the first item that has the same number of characters or fewer. For example, if I select "01", I can immediately see that the next two items ("010", "011") need to be removed, but not the one after that, because "1" has fewer characters. Removing "0" afterwards is easy. Here is an implementation:

library(fastmatch)  # could use `match`, but we repeatedly search against same hash

# `pool` must be sorted!

sample01 <- function(pool, n) {
  picked <- logical(length(pool))
  chrs <- nchar(pool)
  pick.list <- character(n)
  pool.seq <- seq_along(pool)

  for(i in seq(n)) {
    # Make sure pool not exhausted

    left <- which(!picked)
    left.len <- length(left)
    if(!length(left)) break

    # Sample from pool

    seq.left <- seq.int(left)
    pool.left <- pool[left]
    chrs.left <- chrs[left]
    pick <- sample(length(pool.left), 1L)

    # Find all the elements with more characters that are disqualified
    # and store their indices in `valid` (bad name...)

    valid.tmp <- chrs.left > chrs.left[[pick]] & seq.left > pick
    first.invalid <- which(!valid.tmp & seq.left > pick)
    valid <- if(length(first.invalid)) {
      pick:(first.invalid[[1L]] - 1L)
    } else pick:left.len

    # Translate back to original pool indices since we're working on a 
    # subset in `pool.left`

    pool.seq.left <- pool.seq[left]
    pool.idx <- pool.seq.left[valid]
    val <- pool[[pool.idx[[1L]]]]

    # Record the picked value, and all the disqualifications

    pick.list[[i]] <- val
    picked[pool.idx] <- TRUE

    # Disqualify shorter matches

    to.rem <- vapply(
      seq.int(nchar(val) - 1), substr, character(1L), x=val, start=1L
    )
    to.rem.idx <- fmatch(to.rem, pool, nomatch=0)
    picked[to.rem.idx] <- TRUE  
  }
  pick.list  
}
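The disqualification scan at the heart of sample01 can be sketched more plainly (a simplified Python analogue; the index bookkeeping and names are mine):

```python
def disqualified_by(pool_sorted, i):
    """Indices knocked out by picking pool_sorted[i]: the pick itself,
    the contiguous run of longer strings that follows it (its
    extensions, adjacent in sorted order), and any shorter strings
    that prefix it."""
    pick = pool_sorted[i]
    out = {i}
    j = i + 1
    while j < len(pool_sorted) and len(pool_sorted[j]) > len(pick):
        out.add(j)                       # longer follower, disqualified
        j += 1
    for k in range(1, len(pick)):        # each proper prefix of the pick
        if pick[:k] in pool_sorted:
            out.add(pool_sorted.index(pick[:k]))
    return sorted(out)
```

On the sorted max_len 3 pool shown above, picking "01" (index 4) knocks out indices 0, 4, 5, and 6, i.e. "0", "01", "010", and "011".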

And a function to create the sorted pool (exactly like your code, but returns it sorted):

make_pool <- function(size)
  sort(
    unlist(
      lapply(
        seq_len(size), 
        function(x) do.call(paste0, expand.grid(rep(list(c('0', '1')), x))) 
  ) ) )

Then, using a max_len 3 pool (useful for visual inspection that things work as intended):

pool3 <- make_pool(3)
set.seed(1)
sample01(pool3, 8)
# [1] "001" "1"   "010" "011" "000" ""    ""    ""   
sample01(pool3, 8)
# [1] "110" "111" "011" "10"  "00"  ""    ""    ""   
sample01(pool3, 8)
# [1] "000" "01"  "11"  "10"  "001" ""    ""    ""   
sample01(pool3, 8)
# [1] "011" "101" "111" "001" "110" "100" "000" "010"    

Note that in the last case we get all the 3-digit binary combinations (2^3), because by chance we kept sampling from the 3-digit ones. Also, with just a size-3 pool there are many samplings that prevent a full 8 draws; you could address that with your suggestion of eliminating the combinations that prevent full draws from the pool.

This is pretty fast. Looking at the max_len==9 example that took two seconds with the alternate solution:

library(microbenchmark)
pool9 <- make_pool(9)
microbenchmark(sample01(pool9, 4))
# Unit: microseconds
#                expr     min      lq  median      uq     max neval
#  sample01(pool9, 4) 493.107 565.015 571.624 593.791 983.663   100    

Only about half a millisecond. You can also reasonably try fairly large pools:

pool16 <- make_pool(16)  # 131K entries
system.time(sample01(pool16, 100))
#  user  system elapsed 
# 3.407   0.146   3.552 

This is not as fast, but then we're talking about a pool with 130K items. There is also potential for additional optimization.

Note that the sorting step becomes relatively slow for large pools, but I'm not counting it, since you only ever need to do it once, and you could probably come up with a reasonable algorithm to produce the pool pre-sorted.

There is also the possibility of a much faster integer-to-binary based approach that I explored in a now-deleted answer, but that requires a fair bit more work to tie out exactly to what you're looking for.

Introduction

This is a numeric variant of the string algorithm we implemented in the other answer. It is faster, and requires neither creating nor sorting the pool.

Algorithm outline

We can use integers to represent your binary strings, which greatly simplifies the problem of generating the pool and of eliminating values sequentially. For example, with max_len==3, we can take the number 1-- (where - stands for padding) to represent 4 in decimal. Further, we can determine that the numbers that need to be eliminated if we pick this number are those between 4 and 4 + 2^x - 1, where x is the number of padding elements (2 in this case). So the numbers to eliminate are those between 4 and 4 + 2^2 - 1 (i.e. between 4 and 7, represented as 100, 101, 110, and 111).
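That elimination range is simple to compute (a Python sketch; names are mine):

```python
def elimination_range(bits, max_len):
    """Inclusive range of integers eliminated by picking the binary
    string `bits`, treated as right-padded to max_len digits."""
    x = max_len - len(bits)        # number of padding elements
    v = int(bits, 2) << x          # value of the right-padded string
    return v, v + 2**x - 1         # [v, v + 2^x - 1]
```

elimination_range("1", 3) gives (4, 7), matching the 1-- example.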

In order to match your problem exactly, we need some refinements, since you treat numbers that are potentially the same in binary as distinct for some parts of your algorithm. For example, 100, 10-, and 1-- are all the same number, but need to be treated differently in your scheme. In a max_len==3 world, we have 8 possible numbers, but 14 possible representations:

0 - 000: 0--, 00-
1 - 001:
2 - 010: 01-
3 - 011:
4 - 100: 1--, 10-
5 - 101:
6 - 110: 11-
7 - 111:

So 0 and 4 have three possible encodings, 2 and 6 have two, and all the others just one. We need to generate an integer pool that represents the numbers with multiple representations with a higher selection probability, as well as a mechanism for tracking how many blanks a number contains. We can do this by appending a couple of bits to the end of our numbers to indicate the weightings we want. So our numbers become (using two bits here):

jbaum | int | bin | bin.enc | int.enc    
  0-- |   0 | 000 |   00000 |       0
  00- |   0 | 000 |   00001 |       1      
  000 |   0 | 000 |   00010 |       2      
  001 |   1 | 001 |   00110 |       6      
  01- |   2 | 010 |   01001 |       9  
  010 |   2 | 010 |   01010 |      10  
  011 |   3 | 011 |   01110 |      14  
  1-- |   4 | 100 |   10000 |      16  
  10- |   4 | 100 |   10001 |      17  
  100 |   4 | 100 |   10010 |      18  
  101 |   5 | 101 |   10110 |      22  
  11- |   6 | 110 |   11001 |      25   
  110 |   6 | 110 |   11010 |      26   
  111 |   7 | 111 |   11110 |      30

A few useful properties:

  • enc.bits is the number of bits we need for the encoding (2 bits in this case)
  • int.enc %% 2 ^ enc.bits tells us how many of the digits are explicitly specified
  • int.enc %/% 2 ^ enc.bits returns int
  • int * 2 ^ enc.bits + explicitly.specified returns int.enc

Note that explicitly.specified here ranges between 0 and max_len - 1 in our implementation, since there is always at least one digit specified. We now have an encoding that fully represents your data structure using integers only. We can sample from these integers and reproduce your desired outcome with the correct weightings, etc. One limitation of this approach is that we use 32-bit integers in R, and we must reserve some bits for the encoding, so we are limited to pools of max_len==25 or thereabouts. You could go bigger if you used integers specified by double-precision floating point, but we didn't do that here.
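Assuming explicitly.specified is stored as the number of specified digits minus one (consistent with the 0 to max_len - 1 range just noted), the round trip between int.enc and (int, digits) is plain integer arithmetic (a Python sketch; names are mine):

```python
ENC_BITS = 2  # bits reserved for the encoding; enough for max_len == 3

def encode(value, digits):
    """int.enc = int * 2^enc.bits + explicitly.specified,
    with explicitly.specified assumed to be digits - 1."""
    return value * 2**ENC_BITS + (digits - 1)

def decode(int_enc):
    """Recover (int, number of specified digits)."""
    return int_enc // 2**ENC_BITS, int_enc % 2**ENC_BITS + 1
```

For instance, 001 (int 1, three specified digits) encodes to 6, and 1-- (int 4, one specified digit) encodes to 16.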

Avoiding duplicate picks

There are two rough ways of ensuring that we don't pick the same value twice:

  1. Keep track of which values remain available for picking, and sample randomly from those
  2. Sample randomly from all possible values, and then check whether the value has been picked already; if it has, sample again

While the first option seems the cleanest, it is actually very expensive computationally. It requires either doing a vector scan of all possible values, for every pick, to pre-disqualify the picked values, or creating a shrinking vector containing the non-disqualified values. The shrinking option is only more efficient than the vector scan if the vector is made to shrink by reference via C code, but even then it requires repeated translations of potentially large portions of the vector, and it requires C.

Here we use method #2. This allows us to randomly shuffle the universe of possible values once, and then pick each value in sequence, check that it hasn't been disqualified, and if it has, pick another one, and so on. This works because checking whether a value has been picked is trivial as a result of our value encoding: we can infer the location of a value in a sorted table based on the value alone. So we record the status of each value in a sorted table, and can update or look up that status via direct index access (no scan required).
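Method #2 can be sketched generically: shuffle the universe once, then walk it in order, skipping anything an earlier pick disqualified. For clarity this sketch disqualifies by string prefix tests rather than by the O(1) encoded-index lookup described above (names are mine):

```python
import random

def sample_prefix_free(pool, n, seed=None):
    """Walk a one-time shuffle of pool, keeping values compatible with
    everything picked so far; may return fewer than n values if the
    walk exhausts the pool (i.e. an incomplete draw)."""
    order = list(pool)
    random.Random(seed).shuffle(order)
    picked = []
    for cand in order:
        if any(cand.startswith(p) or p.startswith(cand) for p in picked):
            continue                     # disqualified by an earlier pick
        picked.append(cand)
        if len(picked) == n:
            break
    return picked
```

Every returned sample is prefix-free by construction, though not necessarily of full size n.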

Examples

An implementation of this algorithm in base R is available as a gist. This particular implementation only pulls complete draws. Here are 10 draws of 8 elements each from a max_len==4 pool:

# each column represents a draw from a `max_len==4` pool

set.seed(6); replicate(10, sample0110b(4, 8))
     [,1]   [,2]   [,3]   [,4]   [,5]   [,6]   [,7]   [,8]   [,9]   [,10] 
[1,] "1000" "1"    "0011" "0010" "100"  "0011" "0"    "011"  "0100" "1011"
[2,] "111"  "0000" "1101" "0000" "0110" "0100" "1000" "00"   "0101" "1001"
[3,] "0011" "0110" "1001" "0100" "0000" "0101" "1101" "1111" "10"   "1100"
[4,] "0100" "0010" "0000" "0101" "1101" "101"  "1011" "1101" "0110" "1101"
[5,] "101"  "0100" "1100" "1100" "0101" "1001" "1001" "1000" "1111" "1111"
[6,] "110"  "0111" "1011" "111"  "1011" "110"  "1111" "0100" "0011" "000" 
[7,] "0101" "0101" "111"  "011"  "1010" "1000" "1100" "101"  "0001" "0101"
[8,] "011"  "0001" "01"   "1010" "0011" "1110" "1110" "1001" "110"  "1000"

We originally also had two implementations that relied on method #1 for avoiding duplicates, one in base R and one in C, but even the C version was not as fast as the new base R version when n is large. Those functions do implement the ability to draw incomplete samples, so we provide them here for reference:

Comparative benchmarks

Here is a set of benchmarks comparing several of the functions that appear in this Q/A. Times are in milliseconds. The brodie.b version is the one described in this answer. brodie is the original implementation, and brodie.C is the original with some C. All of these enforce the requirement of complete samples. brodie.str is the string-based version from the other answer.

   size    n  jbaum josilber  frank tensibai brodie.b brodie brodie.C brodie.str
1     4   10     11        1      3        1        1      1        1          0
2     4   50      -        -      -        1        -      -        -          1
3     4  100      -        -      -        1        -      -        -          0
4     4  256      -        -      -        1        -      -        -          1
5     4 1000      -        -      -        1        -      -        -          1
6     8   10      1      290      6        3        2      2        1          1
7     8   50    388        -      8        8        3      4        3          4
8     8  100  2,506        -     13       18        6      7        5          5
9     8  256      -        -     22       27       13     14       12          6
10    8 1000      -        -      -       27        -      -        -          7
11   16   10      -        -    615      688       31     61       19        424
12   16   50      -        -  2,123    2,497       28    276       19      1,764
13   16  100      -        -  4,202    4,807       30    451       23      3,166
14   16  256      -        - 11,822   11,942       40  1,077       43      8,717
15   16 1000      -        - 38,132   44,591       83  3,345      130     27,768

This scales relatively well to larger pools:

system.time(sample0110b(18, 100000))
   user  system elapsed 
  8.441   0.079   8.527 

Notes on the benchmarks:

  • frank and brodie (minus brodie.str) do not require any pre-generation of the pool, which affects the comparisons (see below)
  • josilber is the LP version
  • jbaum is the OP's example
  • tensibai is slightly modified to exit rather than fail when the pool is empty
  • I'm not set up to run python, so I could not compare or fully account for the buffering
  • - represents options that were either unfeasible or too slow to time reasonably

The timings do not include drawing the pools (0.8, 2.5, and 401 milliseconds for sizes 4, 8, and 16 respectively), which is required for the jbaum, josilber, and brodie.str runs, or sorting them (0.1, 2.7, and 3,700 milliseconds for sizes 4, 8, and 16), which is required for brodie.str in addition to the draw. Whether you want to include these depends on how many times you run the function for a particular pool. Also, there are almost certainly better ways to generate/sort the pools.

These are the middle times of three runs of microbenchmark. The code is available as a gist, though note that you will have to load the sample0110b, sample0110, sample01101, and sample01 functions beforehand.

Mapping ids to strings. You can map numbers to 0/1 vectors, as @BrodieG mentioned:

# some key objects

n_pool      = sum(2^(1:max_len))      # total number of indices
cuts        = cumsum(2^(1:max_len-1)) # new group starts
inds_by_g   = mapply(seq,cuts,cuts*2) # indices grouped by length

# the mapping to strings (one among many possibilities)

library(data.table)
get_01str <- function(id,max_len){
    cuts = cumsum(2^(1:max_len-1))
    g    = findInterval(id,cuts)
    gid  = id-cuts[g]+1

    data.table(g,gid)[,s:=
      do.call(paste,c(list(sep=""),lapply(
        seq(g[1]), 
        function(x) (gid-1) %/% 2^(x-1) %% 2
      )))
    ,by=g]$s      
} 

Finding ids to drop. We'll sequentially drop ids from the sampling pool:

 # the mapping from one index to indices of nixed strings

get_nixstrs <- function(g,gid,max_len){

    cuts         = cumsum(2^(1:max_len-1))
    gids_child   = {
      x = gid%%2^sequence(g-1)
      ifelse(x,x,2^sequence(g-1))
    }
    ids_child    = gids_child+cuts[sequence(g-1)]-1

    ids_parent   = if (g==max_len) gid+cuts[g]-1 else {

      gids_par       = vector(mode="list",max_len)
      gids_par[[g]]  = gid
      for (gg in seq(g,max_len-1)) 
        gids_par[[gg+1]] = c(gids_par[[gg]],gids_par[[gg]]+2^gg)

      unlist(mapply(`+`,gids_par,cuts-1))
    }

    c(ids_child,ids_parent)
}

The indices are grouped by g, the number of characters, nchar(get_01str(id)). Because the indices are sorted by g, g=findInterval(id,cuts) is a faster route.

g1 < g < max_len 中的一个索引有一个 "child" 个大小为 g-1 的索引和两个 parent 个大小为 g+1 的索引.对于每个 child 节点,我们获取其 child 节点,直到我们命中 g==1;对于每个 parent 节点,我们获取它们的一对 parent 节点,直到我们达到 g==max_len.

The structure of the tree is simplest in terms of the within-group identifier, gid. gid maps to the two parents, gid and gid+2^g; and reversing this mapping finds the child.
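That within-group arithmetic can be written out directly (a Python sketch of the gid bookkeeping used in get_nixstrs; function names are mine):

```python
def parent_gids(gid, g):
    """gids, in group g+1, of the two one-digit extensions of a
    group-g string with identifier gid."""
    return gid, gid + 2**g

def child_gid(gid, g):
    """gid, in group g-1, of the string obtained by dropping the last
    digit (the inverse of parent_gids); equivalent to the
    ifelse(x, x, 2^(g-1)) expression in the R code."""
    return (gid - 1) % 2**(g - 1) + 1
```

Both parents map back to the same child: applying child_gid in group g+1 to either element of parent_gids(gid, g) returns gid.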

Sampling

drawem <- function(n,max_len){
    cuts        = cumsum(2^(1:max_len-1))
    inds_by_g   = mapply(seq,cuts,cuts*2)

    oklens = (1:max_len)[ n <= 2^max_len*(1-2^(-(1:max_len)))+1 ]
    okinds = unlist(inds_by_g[oklens])

    mysamp = rep(0,n)
    for (i in 1:n){

        id        = if (length(okinds)==1) okinds else sample(okinds,1)
        g         = findInterval(id,cuts)
        gid       = id-cuts[g]+1
        nixed     = get_nixstrs(g,gid,max_len)

        # print(id); print(okinds); print(nixed)

        mysamp[i] = id
        okinds    = setdiff(okinds,nixed)
        if (!length(okinds)) break
    }

    res <- rep("",n)
    res[seq.int(i)] <- get_01str(mysamp[seq.int(i)],max_len)
    res
}

The oklens part integrates the OP's idea of omitting strings that are guaranteed to make the sampling impossible. However, even doing that, we may follow a sampling path that leaves us with no options. Taking the OP's example of max_len=4 and n=10, we know we must drop 0 and 1 from consideration, but what happens if our first four draws happen to be 00, 01, 11, and 10? Oh well, I guess we're out of luck. This is why you should really define sampling probabilities. (The OP has another idea, for determining, at each step, which nodes will lead to an impossible state, but that seems a tall order.)
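That dead end is easy to confirm by brute force: with max_len = 4, after drawing 00, 01, 11, and 10, nothing in the pool remains compatible (a Python check; names are mine):

```python
from itertools import product

def compatible_remainder(pool, picks):
    """Pool elements that are neither picked nor in a prefix relation
    with any pick."""
    return [e for e in pool
            if e not in picks and
            not any(e.startswith(p) or p.startswith(e) for p in picks)]

# the max_len == 4 pool: all binary strings of 1..4 digits
pool = [''.join(b) for k in range(1, 5) for b in product('01', repeat=k)]
print(compatible_remainder(pool, ['00', '01', '11', '10']))  # []
```

Every 1-character string prefixes one of the picks, and every 3- or 4-character string extends one, so a draw of size 10 is dead after only four picks.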

Illustration

# how the indices line up

n_pool = sum(2^(1:max_len)) 
pdt <- data.table(id=1:n_pool)
pdt[,g:=findInterval(id,cuts)]
pdt[,gid:=1:.N,by=g]
pdt[,s:=get_01str(id,max_len)]

# example run

set.seed(4); drawem(5,5)
# [1] "01100" "1"     "0001"  "0101"  "00101"

set.seed(4); drawem(8,4)
# [1] "1100" "0"    "111"  "101"  "1101" "100"  ""     ""  

Benchmarks (older than the ones in @BrodieG's answer)

require(rbenchmark)
max_len = 8
n = 8

benchmark(
      jos_lp     = {
        pool <- unlist(lapply(seq_len(max_len),
          function(x) do.call(paste0, expand.grid(rep(list(c('0', '1')), x)))))
        sample.lp(pool, n)},
      bro_string = {pool <- make_pool(max_len);sample01(pool,n)},
      fra_num    = drawem(n,max_len),
      replications=5)[1:5]
#         test replications elapsed relative user.self
# 2 bro_string            5    0.05      2.5      0.05
# 3    fra_num            5    0.02      1.0      0.02
# 1     jos_lp            5    1.56     78.0      1.55

n = 12
max_len = 12
benchmark(
  bro_string={pool <- make_pool(max_len);sample01(pool,n)},
  fra_num=drawem(n,max_len),
  replications=5)[1:5]
#         test replications elapsed relative user.self
# 1 bro_string            5    0.54     6.75      0.51
# 2    fra_num            5    0.08     1.00      0.08

Other answers. There are also two more answers:

jos_enum = {pool <- unlist(lapply(seq_len(max_len), 
    function(x) do.call(paste0, expand.grid(rep(list(c('0', '1')), x)))))
  get.template(pool, n)}
bro_num  = sample011(max_len,n)    

I omitted @josilber's enumeration approach, as it took far too long; and @BrodieG's numeric/index approach, as it didn't work at the time, but now does. See @BrodieG's updated answer for more benchmarking.

Speed vs. correctness. While @josilber's answers are far slower (and, for the enumeration approach, apparently much more memory-intensive), they are guaranteed to draw a sample of size n on the first try. With @BrodieG's string approach or this answer, you'll have to resample again and again in the hope of drawing a full n. With large max_len, that shouldn't be such a problem, I suppose.

This answer scales better than bro_string because it doesn't require pre-building pool.

It's in python instead of r, but jbaums said that would be fine.

So here is my contribution; see the comments in the source for an explanation of the crucial parts.
I'm still working on an analytical solution to determine the number of possible combinations c of a sample of size S from a tree of depth t, so that I can improve the function combs. Maybe somebody has it? This is really the bottleneck now.

Sampling 100 nodes from a tree of depth 16 takes around 8 ms on my laptop. Not the first time though; it gets faster up to a certain point the more you sample, because combBuffer gets filled.

import random


class Tree(object):
    """
    :param level: The distance of this node from the root.
    :type level: int
    :param parent: This trees parent node
    :type parent: Tree
    :param isleft: Determines if this is a left or a right child node. Can be
                   omitted if this is the root node.
    :type isleft: bool

    A binary tree representing possible strings which match r'[01]{1,n}'. Its
    purpose is to be able to sample n of its nodes where none of the sampled
    nodes' ids is a prefix for another one.
    It is possible to change Tree.maxdepth and then reuse the root. All
    children are created ON DEMAND, which means everything is lazily evaluated.
    If the Tree gets too big anyway, you can call 'prune' on any node to delete
    its children.

        >>> t = Tree()
        >>> t.sample(8, toString=True, depth=3)
        ['111', '110', '101', '100', '011', '010', '001', '000']
        >>> Tree.maxdepth = 2
        >>> t.sample(4, toString=True)
        ['11', '10', '01', '00']
    """

    maxdepth = 10
    _combBuffer = {}

    def __init__(self, level=0, parent=None, isleft=None):
        self.parent = parent
        self.level = level
        self.isleft = isleft
        self._left = None
        self._right = None

    @classmethod
    def setMaxdepth(cls, depth):
        """
        :param depth: The new depth
        :type depth: int

        Sets the maxdepth of the Trees. This basically is the depth of the root
        node.
        """
        if cls.maxdepth == depth:
            return

        cls.maxdepth = depth

    @property
    def left(self):
        """This tree's left child, 'None' if this is a leaf node"""
        if self.depth == 0:
            return None

        if self._left is None:
            self._left = Tree(self.level+1, self, True)
        return self._left

    @property
    def right(self):
        """This tree's right child, 'None' if this is a leaf node"""
        if self.depth == 0:
            return None

        if self._right is None:
            self._right = Tree(self.level+1, self, False)
        return self._right

    @property
    def depth(self):
        """
        This tree's depth. (maxdepth-level)
        """
        return self.maxdepth-self.level

    @property
    def id(self):
        """
        This tree's id, string of '0's and '1's equal to the path from the root
        to this subtree. Where '1' means going left and '0' means going right.
        """
        # level 0 is the root node, it has no id
        if self.level == 0:
            return ''
        # This takes at most Tree.maxdepth recursions. Therefore
        # it is safe to do it this way. We could also save each node's
        # id once it is created to avoid recreating it every time, however
        # this won't save much time but would use quite some space.
        return self.parent.id + ('1' if self.isleft else '0')

    @property
    def leaves(self):
        """
        The number of leaf nodes this tree has. (2**depth)
        """
        return 2**self.depth

    def __str__(self):
        return self.id

    def __len__(self):
        return 2*self.leaves-1

    def prune(self):
        """
        Recursively prune this tree's children.
        """
        if self._left is not None:
            self._left.prune()
            self._left.parent = None
            self._left = None

        if self._right is not None:
            self._right.prune()
            self._right.parent = None
            self._right = None

    def combs(self, n):
        """
        :param n: The amount of samples to be taken from this tree
        :type n: int
        :returns: The amount of possible combinations to choose n samples from
                  this tree

        Determines recursively the amount of combinations of n nodes to be
        sampled from this tree.
        Subsequent calls with same n on trees with same depth will return the
        result from the previous computation rather than computing it again.

            >>> t = Tree()
            >>> Tree.maxdepth = 4
            >>> t.combs(16)
            1
            >>> Tree.maxdepth = 3
            >>> t.combs(6)
            58
        """

        # important for the amount of combinations is only n and the depth of
        # this tree
        key = (self.depth, n)

        # We use the dict to save computation time. Calling the function with
        # equal values on equal nodes just returns the already computed value if
        # possible.
        if key not in Tree._combBuffer:
            leaves = self.leaves

            if n < 0:
                N = 0
            elif n == 0 or self.depth == 0 or n == leaves:
                N = 1
            elif n == 1:
                return (2*leaves-1)
            else:
                if n > leaves/2:
                    # if n > leaves/2, at least n-leaves/2 have to stay on
                    # either side, otherwise the other one would have to
                    # sample more nodes than possible.
                    nMin = n-leaves/2
                else:
                    nMin = 0

                # The rest n-2*nMin is the amount of samples that are free to
                # fall on either side
                free = n-2*nMin

                N = 0
                # sum up the combinations of all possible splits
                for addLeft in range(0, free+1):
                    nLeft = nMin + addLeft
                    nRight = n - nLeft
                    N += self.left.combs(nLeft)*self.right.combs(nRight)

            Tree._combBuffer[key] = N
            return N
        return Tree._combBuffer[key]

    def sample(self, n, toString=False, depth=None):
        """
        :param n: How many samples to take from this tree
        :type n: int
        :param toString: If 'True' the result will directly be turned into a list
                         of strings
        :type toString: bool
        :param depth: If not None, will overwrite Tree.maxdepth
        :type depth: int or None
        :returns: List of n nodes sampled from this tree
        :throws ValueError: when n is invalid

        Takes n random samples from this tree where none of the sample's ids is
        a prefix for another one's.

        For an example see Tree's docstring.
        """
        if depth is not None:
            Tree.setMaxdepth(depth)

        if toString:
            return [str(e) for e in self.sample(n)]

        if n < 0:
            raise ValueError('Negative sample size is not possible!')

        if n == 0:
            return []

        leaves = self.leaves
        if n > leaves:
            raise ValueError(('Cannot sample {} nodes, with only {} ' +
                              'leaves!').format(n, leaves))

        # Only one sample to choose, that is nice! We are free to take any node
        # from this tree, including this very node itself.
        if n == 1 and self.level > 0:
            # This tree has 2*leaves-1 nodes, therefore
            # the probability that we keep the root node has to be
            # 1/(2*leaves-1) = P_root. Lets create a random number from the
            # interval [0, 2*leaves-1).
            # It will be 0 with probability 1/(2*leaves-1)
            P_root = random.randint(0, len(self)-1)
            if P_root == 0:
                return [self]
            else:
                # The probability to land here is 1-P_root

                # A child tree's size is (leaves-1) and since it obeys the same
                # rule as above, the probability for each of its nodes to
                # 'survive' is 1/(leaves-1) = P_child.
                # However all nodes must have equal probability, therefore to
                # make sure that their probability is also P_root we multiply
                # them by 1/2*(1-P_root). The latter is already done, the
                # former will be achieved by the next condition.
                # If we do everything right, this should hold:
                # 1/2 * (1-P_root) * P_child = P_root

                # Lets see...
                # 1/2 * (1-1/(2*leaves-1)) * (1/leaves-1)
                # (1-1/(2*leaves-1)) * (1/(2*(leaves-1)))
                # (1-1/(2*leaves-1)) * (1/(2*leaves-2))
                # (1/(2*leaves-2)) - 1/((2*leaves-2) * (2*leaves-1))
                # (2*leaves-1)/((2*leaves-2) * (2*leaves-1)) - 1/((2*leaves-2) * (2*leaves-1))
                # (2*leaves-2)/((2*leaves-2) * (2*leaves-1))
                # 1/(2*leaves-1)
                # There we go!
                if random.random() < 0.5:
                    return self.right.sample(1)
                else:
                    return self.left.sample(1)

        # Now comes the tricky part... n > 1 therefore we are NOT going to
        # sample this node. Its probability to be chosen is 0!
        # It HAS to be 0 since we are definitely sampling from one of its
        # children which means that this node will be blocked by those samples.
        # The difficult part now is to prove that the sampling the way we do it
        # is really random.

        if n > leaves/2:
            # if n > leaves/2, at least n-leaves/2 have to stay on either
            # side, otherwise the other one would have to sample more
            # nodes than possible.
            nMin = n-leaves/2
        else:
            nMin = 0
        # The rest n-2*nMin is the amount of samples that are free to fall
        # on either side
        free = n-2*nMin

        # Let's have a look at an example, suppose we were to distribute 5
        # samples among two children which have 4 leaves each.
        # Each child has to get at least 1 sample, so the free samples are 3.
        # There are 4 different ways to split the samples among the
        # children (left, right):
        # (1, 4), (2, 3), (3, 2), (4, 1)
        # The amount of unique sample combinations per child are
        # (7, 1), (11, 6), (6, 11), (1, 7)
        # The amount of total unique samples per possible split are
        #   7   ,   66  ,   66  ,    7
        # In case of the first and last split, all samples have a probability
        # of 1/7, this was already proven above.
        # Lets suppose we are good to go and the per sample probabilities for
        # the other two cases are (1/11, 1/6) and (1/6, 1/11), this way the
        # overall per sample probabilities for the splits would be:
        #  1/7  ,  1/66 , 1/66 , 1/7
        # If we used uniform random to determine the split, all splits would be
        # equally probable and therefore be multiplied with the same value (1/4)
        # But this would mean that NOT every sample is equally probable!
        # We need to know in advance how many sample combinations there will be
        # for a given split in order to find out the probability to choose it.
        # In fact, due to the restrictions, this becomes very nasty to
        # determine. So instead of solving it analytically, I do it numerically
        # with the method 'combs'. It gives me the amount of possible sample
        # combinations for a certain amount of samples and a given tree depth.
        # It will return 146 for this node and 7 for the outer and 66 for the
        # inner splits.
        # What we now do is, we take a number from [0, 146).
        # if it is smaller than 7, we sample from the first split,
        # if it is smaller than 7+66, we sample from the second split,
        # ...
        # This way we get the probabilities we need.

        r = random.randint(0, self.combs(n)-1)
        p = 0
        for addLeft in xrange(0, free+1):
            nLeft = nMin + addLeft
            nRight = n - nLeft

            p += (self.left.combs(nLeft) * self.right.combs(nRight))
            if r < p:
                return self.left.sample(nLeft) + self.right.sample(nRight)
        assert False, ('Something really strange happened, p did not sum up ' +
                       'to combs or r was too big')


def main():
    """
    Do a microbenchmark.
    """
    import timeit
    i = 1
    main.t = Tree()
    template = ' {:>2}  {:>5} {:>4}  {:<5}'
    print(template.format('i', 'depth', 'n', 'time (ms)'))
    N = 100
    for depth in [4, 8, 15, 16, 17, 18]:
        for n in [10, 50, 100, 150]:
            if n > 2**depth:
                time = '--'
            else:
                time = timeit.timeit(
                    'main.t.sample({}, depth={})'.format(n, depth), setup=
                    'from __main__ import main', number=N)*1000./N
            print(template.format(i, depth, n, time))
            i += 1


if __name__ == "__main__":
    main()

Benchmark output:

  i  depth    n  time (ms)
  1      4   10  0.182511806488
  2      4   50  --   
  3      4  100  --   
  4      4  150  --   
  5      8   10  0.397620201111
  6      8   50  1.66054964066
  7      8  100  2.90236949921
  8      8  150  3.48146915436
  9     15   10  0.804011821747
 10     15   50  3.7428188324
 11     15  100  7.34910964966
 12     15  150  10.8230614662
 13     16   10  0.804491043091
 14     16   50  3.66818904877
 15     16  100  7.09567070007
 16     16  150  10.404779911
 17     17   10  0.865840911865
 18     17   50  3.9999294281
 19     17  100  7.70257949829
 20     17  150  11.3758206367
 21     18   10  0.915451049805
 22     18   50  4.22935962677
 23     18  100  8.22361946106
 24     18  150  12.2081303596

10 samples of size 10 from a tree of depth 10:

['1111010111', '1110111010', '1010111010', '011110010', '0111100001', '011101110', '01110010', '01001111', '0001000100', '000001010']
['110', '0110101110', '0110001100', '0011110', '0001111011', '0001100010', '0001100001', '0001100000', '0000011010', '0000001111']
['11010000', '1011111101', '1010001101', '1001110001', '1001100110', '10001110', '011111110', '011001100', '0101110000', '001110101']
['11111101', '110111', '110110111', '1101010101', '1101001011', '1001001100', '100100010', '0100001010', '0100000111', '0010010110']
['111101000', '1110111101', '1101101', '1101000000', '1011110001', '0111111101', '01101011', '011010011', '01100010', '0101100110']
['1111110001', '11000110', '1100010100', '101010000', '1010010001', '100011001', '100000110', '0100001111', '001101100', '0001101101']
['111110010', '1110100', '1101000011', '101101', '101000101', '1000001010', '0111100', '0101010011', '0101000110', '000100111']
['111100111', '1110001110', '1100111111', '1100110010', '11000110', '1011111111', '0111111', '0110000100', '0100011', '0010110111']
['1101011010', '1011111', '1011100100', '1010000010', '10010', '1000010100', '0111011111', '01010101', '001101', '000101100']
['111111110', '111101001', '1110111011', '111011011', '1001011101', '1000010100', '0111010101', '010100110', '0100001101', '0010000000']

I found this question interesting, so I tried to solve it with my very limited R skills (so this can probably be improved):

Newer, edited version, thanks to @Franck's suggestion:

library(microbenchmark)
library(lineprof)

max_len <- 16
pool <- unlist(lapply(seq_len(max_len), function(x) 
  do.call(paste0, expand.grid(rep(list(c('0', '1')), x)))))
n <- 100

library(stringr)
tree_sample <- function(samples,pool) {
  results <- vector("character",samples)
  # Will be used on a regular basis, compute it in advance
  PoolLen <- str_length(pool)
  # Make a mask vector based on the length of each entry of the pool
  masks <- strtoi(str_pad(str_pad("1",PoolLen,"right","1"),max_len,"right","0"),base=2)

  # Make an integer vector from the "0" right-padded original: for max_len=4 and pool entry "1" we get "1000" => 8
  # This will allow finding this entry as a parent of 10 and 11, which become "1000" and "1100" (integers 8 and 12 respectively):
  # once bitwise "and"ed with the respective mask "1000", the first bit is strictly the same, so it's a parent.
  integerPool <- strtoi(str_pad(pool,max_len,"right","0"),base=2)

  # Create a vector to filter the available value to sample
  ok <- rep(TRUE,length(pool))

  # Precompute the result of the bitwise and between our integer pool and the masks
  MaskedPool <- bitwAnd(integerPool,masks)

  while(samples) {
    samp <- sample(pool[ok],1) # Get a sample
    results[samples] <- samp # Store it as result
    ok[pool == samp] <- FALSE # Remove it from available entries

    vsamp <- strtoi(str_pad(samp,max_len,"right","0"),base=2) # Get the integer value of the "0" right padded sample
    mlen <- str_length(samp) # Get sample len

    # Create a unitary mask to remove children of the sample
    mask <- strtoi(paste0(rep(1:0,c(mlen,max_len-mlen)),collapse=""),base=2)

    # Get the result of bitwise And between the integerPool and the sample mask 
    FilterVec <- bitwAnd(integerPool,mask)

    # Get the bitwise and result of the sample and it's mask
    Childm <- bitwAnd(vsamp,mask)

    ok[FilterVec == Childm] <- FALSE  # Remove from available entries the children of the sample
    ok[MaskedPool == bitwAnd(vsamp,masks)] <- FALSE # compare the sample with all the masks to remove parents matching

    samples <- samples -1
  }
  print(results)
}
microbenchmark(tree_sample(n,pool),times=10L)

The main idea is to use a bitmask comparison to know whether one sample is a parent of another (common bit prefix), and if so to suppress that element from the pool.

Drawing 100 samples from a pool of length 16 now takes 1.4 seconds on my machine.
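The mask trick can be sketched independently of R. Here is a minimal Python re-implementation of the same idea (the names `to_padded_int`, `mask_for`, and `is_prefix` are my own, not from the R code):

```python
MAX_LEN = 4

def to_padded_int(s, max_len=MAX_LEN):
    # right-pad with '0' to max_len, then read as a binary number:
    # "1" -> "1000" -> 8, "11" -> "1100" -> 12
    return int(s.ljust(max_len, '0'), 2)

def mask_for(s, max_len=MAX_LEN):
    # ones over the first len(s) bits: "01" -> "1100" -> 12
    return int(('1' * len(s)).ljust(max_len, '0'), 2)

def is_prefix(a, b, max_len=MAX_LEN):
    # a is a prefix of b (or equal to it) iff the two padded integers
    # agree on every bit covered by a's mask. The padding alone cannot
    # distinguish "01" from "010" (both pad to "0100"), hence the guard.
    if len(a) > len(b):
        return False
    m = mask_for(a, max_len)
    return (to_padded_int(a, max_len) & m) == (to_padded_int(b, max_len) & m)

print(is_prefix('01', '010'))   # True: '01' is an ancestor of '010'
print(is_prefix('1', '1101'))   # True
print(is_prefix('01', '10'))    # False
```

The R code above does not need the explicit length guard because it removes parents and children in two separate passes (`MaskedPool` for parents, `FilterVec` for children), so both directions get eliminated either way.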

Introduction

I found this question intriguing enough that I had to think it over, and ultimately provide my own answer. Since the algorithm I arrived at does not immediately follow from the problem description, I will start by explaining how I arrived at this solution, and then provide a sample implementation in C++ (I have never written R).

Development of the solution

At first glance

Reading the problem statement was initially confusing, but the moment I saw the edit with the pictures of the trees, I immediately understood the problem, and my intuition suggested that a binary tree was also a solution: build a tree (a collection of trees of size 1), and break the tree down into a collection of smaller trees by eliminating branches and ancestors each time a selection is made.

Although this initially looked good, the selection process and the maintenance of the collection would be a pain. Still, the tree seemed like it should play a big part in any solution.

Revision 1

Don't break the tree apart. Instead, give each node a boolean data payload indicating whether it has been eliminated yet. This leaves a single tree that keeps its form.

But note that this isn't just any binary tree: it's actually a complete binary tree of depth max_len-1.

Revision 2

A complete binary tree can be represented very nicely as an array. The typical array representation uses a breadth-first search of the tree, and has the following properties:

Let x be the array index.
x = 0 is the root of the entire tree
left_child(x) = 2x + 1
right_child(x) = 2x + 2
parent(x) = floor((x-1)/2)
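These index relations are easy to sanity-check; a minimal Python sketch:

```python
# Breadth-first array layout of a complete binary tree, root at index 0.
def left_child(x):
    return 2 * x + 1

def right_child(x):
    return 2 * x + 2

def parent(x):
    return (x - 1) // 2

# parent() inverts both child functions for every node
for x in range(100):
    assert parent(left_child(x)) == x
    assert parent(right_child(x)) == x
```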

In the image below, each node is labelled with its array index:

As an array, this takes less memory (no more pointers), the memory used is contiguous (good for caching), and it can live entirely on the stack rather than the heap (assuming your language gives you a choice). Of course some conditions apply here, in particular the size of the array. I'll come back to that later.

Just like Revision 1, the data stored in the array will be booleans: true for available, false for unavailable. Since the root node is not actually a valid choice, index 0 should be initialized to false. The question of how to make a selection still remains:

Since the indices are bounded, it is trivial to track how many have been eliminated, and thus how many remain. Pick a random number in that range, then walk the array until that many indices set to true have been seen (the current index included). The index arrived at is the selection to make. Select until n indices have been selected, or there is nothing left to select.
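That scan can be sketched as follows (an illustrative Python version, not part of the original answer):

```python
import random

def select_available(available, rng):
    # Pick uniformly among the indices still set to True by drawing a
    # rank k in 1..remaining, then walking the array until k True
    # entries have been seen; seen first equals k at a True entry.
    remaining = sum(available)
    k = rng.randrange(1, remaining + 1)
    seen = 0
    for i, ok in enumerate(available):
        seen += ok
        if seen == k:
            return i

avail = [False, True, True, False, True]
i = select_available(avail, random.Random(0))
assert avail[i]   # the selection always lands on an available index
```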

This is a complete algorithm, but there is room for improvement in the selection process, and there is a practical sizing issue not yet addressed: the array size will be O(2^max_len). As max_len gets larger, first the caching benefit disappears, then the data starts being paged to disk, and at some point it cannot be stored at all.

Revision 3

I decided to tackle the easier problem first: improving the selection process.

Scanning the array left-to-right is wasteful. It may be more efficient to track the ranges that have been eliminated, rather than checking and finding several falses in a row. However, our tree representation is not ideal for this, since the nodes eliminated each round are rarely contiguous in the array.

By rearranging how the array maps to the tree, it is possible to exploit it better. In particular, let's use a pre-order depth-first search rather than a breadth-first search. In order to do so, the tree needs to be a fixed size, which is the case for this problem. It is also less obvious how the indices of child and parent nodes are connected mathematically.

With this arrangement, every choice that is not a leaf is guaranteed to eliminate a contiguous range: its subtree.
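For a concrete feel of the pre-order index arithmetic (a Python sketch mirroring the formulas used in the C++ implementation below, with the root at depth 0):

```python
def subtree_size(node_depth, depth):
    # size of the subtree rooted at a node at node_depth, in a tree of
    # the given total depth (root included at node_depth 0)
    return (1 << (depth + 1 - node_depth)) - 1

def children(node, node_depth, depth):
    # pre-order DFS layout: the left child comes immediately after the
    # node, the right child comes after the entire left subtree
    if node_depth == depth:
        return None  # leaf
    left = node + 1
    right = node + 1 + (subtree_size(node_depth, depth) - 1) // 2
    return left, right

# depth-2 tree, 7 nodes, pre-order layout: 0 [1 2 3] [4 5 6]
assert children(0, 0, 2) == (1, 4)
assert children(1, 1, 2) == (2, 3)
assert children(4, 1, 2) == (5, 6)
# node 1's subtree occupies the contiguous index range [1, 3]
assert subtree_size(1, 2) == 3
```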

Revision 4

By tracking the eliminated ranges, the true/false data is no longer needed, and thus neither is the array or the tree. At each random draw, the eliminated ranges can be used to quickly find the node to select. All ancestors and the entire subtree get eliminated, and can be represented as ranges that can easily be merged with the others.
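The draw-then-skip step can be sketched like this (an illustrative Python version of the same idea as `selectNode` in the code below; ranges are inclusive and kept sorted):

```python
def resolve(draw, eliminated):
    # Map a draw from 1..pool_size onto the surviving node indices by
    # shifting it past every eliminated (start, end) range it crosses.
    for start, end in eliminated:
        if draw >= start:
            draw += end - start + 1
        else:
            break
    return draw

# nodes 1..10 with 3..4 and 7 eliminated leave 1,2,5,6,8,9,10 surviving
assert resolve(2, [(3, 4), (7, 7)]) == 2
assert resolve(3, [(3, 4), (7, 7)]) == 5
assert resolve(5, [(3, 4), (7, 7)]) == 8
assert resolve(7, [(3, 4), (7, 7)]) == 10
```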

The last task is to convert the selected node into the string representation the OP wants. This is easy enough, since this binary tree still maintains a strict ordering: traversing from the root, everything >= the right child is on the right, everything else is on the left. Searching the tree thus yields both the list of ancestors and the binary string, by appending '0' when traversing left, or '1' when traversing right.
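That traversal can be sketched as follows (a Python mirror of the walk performed by `getNodeInfo` in the code below, using the same pre-order index arithmetic):

```python
def to_bits(node, depth):
    # Walk from the root toward the target pre-order index, emitting '0'
    # for each step to the left child and '1' for each step to the right.
    bits = ''
    current, d = 0, 0
    while current != node:
        if current == 0:
            size = (1 << (depth + 1)) - 1        # whole tree
        else:
            size = (1 << (depth + 1 - d)) - 1    # subtree at depth d
        right = current + 1 + (size - 1) // 2
        if node >= right:
            bits += '1'
            current = right
        else:
            bits += '0'
            current += 1
        d += 1
    return bits

# depth-2 tree, pre-order layout: 0 [1 2 3] [4 5 6]
assert to_bits(1, 2) == '0'
assert to_bits(3, 2) == '01'
assert to_bits(5, 2) == '10'
```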

Example implementation

#include <cstdint>
#include <algorithm>
#include <cassert>   // assert
#include <cmath>
#include <cstdlib>
#include <deque>
#include <iostream>
#include <list>
#include <random>    // std::random_device, std::mt19937_64, std::uniform_int_distribution
#include <string>    // std::string, std::stoll

/*
 * A range of values of the form (a, b), where a <= b, and is inclusive.
 * Ex (1,1) is the range from 1 to 1 (ie: just 1)
 */
class Range
{
private:
    friend bool operator< (const Range& lhs, const Range& rhs);
    friend std::ostream& operator<<(std::ostream& os, const Range& obj);

    int64_t m_start;
    int64_t m_end;

public:
    Range(int64_t start, int64_t end) : m_start(start), m_end(end) {}
    int64_t getStart() const { return m_start; }
    int64_t getEnd() const { return m_end; }
    int64_t size() const { return m_end - m_start + 1; }
    bool canMerge(const Range& other) const {
        return !((other.m_start > m_end + 1) || (m_start > other.m_end + 1));
    }
    int64_t merge(const Range& other) {
        int64_t change = 0;
        if (m_start > other.m_start) {
            change += m_start - other.m_start;
            m_start = other.m_start;
        }
        if (other.m_end > m_end) {
            change += other.m_end - m_end;
            m_end = other.m_end;
        }
        return change;
    }
};

inline bool operator< (const Range& lhs, const Range& rhs){return lhs.m_start < rhs.m_start;}
std::ostream& operator<<(std::ostream& os, const Range& obj) {
    os << '(' << obj.m_start << ',' << obj.m_end << ')';
    return os;
}

/*
 * Stuct to allow returning of multiple values
 */
struct NodeInfo {
    int64_t subTreeSize;
    int64_t depth;
    std::list<int64_t> ancestors;
    std::string representation;
};

/*
 * Collection of functions representing a complete binary tree
 * as an array created using pre-order depth-first search,
 * with 0 as the root.
 * Depth of the root is defined as 0.
 */
class Tree
{
private:
    int64_t m_depth;
public:
    Tree(int64_t depth) : m_depth(depth) {}
    int64_t size() const {
        return (int64_t(1) << (m_depth+1))-1;
    }
    int64_t getDepthOf(int64_t node) const{
        if (node == 0) { return 0; }
        int64_t searchDepth = m_depth;
        int64_t currentDepth = 1;
        while (true) {
            int64_t rightChild = int64_t(1) << searchDepth;
            if (node == 1 || node == rightChild) {
                break;
            } else if (node > rightChild) {
                node -= rightChild;
            } else {
                node -= 1;
            }
            currentDepth += 1;
            searchDepth -= 1;
        }
        return currentDepth;
    }
    int64_t getSubtreeSizeOf(int64_t node, int64_t nodeDepth = -1) const {
        if (node == 0) {
            return size();
        }
        if (nodeDepth == -1) {
            nodeDepth = getDepthOf(node);
        }
        return (int64_t(1) << (m_depth + 1 - nodeDepth)) - 1;
    }
    int64_t getLeftChildOf(int64_t node, int64_t nodeDepth = -1) const {
        if (nodeDepth == -1) {
            nodeDepth = getDepthOf(node);
        }
        if (nodeDepth == m_depth) { return -1; }
        return node + 1;
    }
    int64_t getRightChildOf(int64_t node, int64_t nodeDepth = -1) const {
        if (nodeDepth == -1) {
            nodeDepth = getDepthOf(node);
        }
        if (nodeDepth == m_depth) { return -1; }
        return node + 1 + ((getSubtreeSizeOf(node, nodeDepth) - 1) / 2);
    }
    NodeInfo getNodeInfo(int64_t node) const {
        NodeInfo info;
        int64_t depth = 0;
        int64_t currentNode = 0;
        while (currentNode != node) {
            if (currentNode != 0) {
                info.ancestors.push_back(currentNode);
            }
            int64_t rightChild = getRightChildOf(currentNode, depth);
            if (rightChild == -1) {
                break;
            } else if (node >= rightChild) {
                info.representation += '1';
                currentNode = rightChild;
            } else {
                info.representation += '0';
                currentNode = getLeftChildOf(currentNode, depth);
            }
            depth++;
        }
        info.depth = depth;
        info.subTreeSize = getSubtreeSizeOf(node, depth);
        return info;
    }
};

// random selection amongst remaining allowed nodes
int64_t selectNode(const std::deque<Range>& eliminationList, int64_t poolSize, std::mt19937_64& randomGenerator)
{
    std::uniform_int_distribution<int64_t> randomDistribution(1, poolSize);
    int64_t selection = randomDistribution(randomGenerator);
    for (auto const& range : eliminationList) {
        if (selection >= range.getStart()) { selection += range.size(); }
        else { break; }
    }
    return selection;
}

// determine how many nodes have been eliminated
int64_t countEliminated(const std::deque<Range>& eliminationList)
{
    int64_t count = 0;
    for (auto const& range : eliminationList) {
        count += range.size();
    }
    return count;
}

// merge all the elimination ranges into listA, and return the number of new eliminations
int64_t mergeEliminations(std::deque<Range>& listA, std::deque<Range>& listB) {
    if(listB.empty()) { return 0; }
    if(listA.empty()) {
        listA.swap(listB);
        return countEliminated(listA);
    }

    int64_t newEliminations = 0;
    int64_t x = 0;
    auto listA_iter = listA.begin();
    auto listB_iter = listB.begin();
    while (listB_iter != listB.end()) {
        if (listA_iter == listA.end()) {
            listA_iter = listA.insert(listA_iter, *listB_iter);
            x = listB_iter->size();
            assert(x >= 0);
            newEliminations += x;
            ++listB_iter;
        } else if (listA_iter->canMerge(*listB_iter)) {
            x = listA_iter->merge(*listB_iter);
            assert(x >= 0);
            newEliminations += x;
            ++listB_iter;
        } else if (*listB_iter < *listA_iter) {
            listA_iter = listA.insert(listA_iter, *listB_iter) + 1;
            x = listB_iter->size();
            assert(x >= 0);
            newEliminations += x;
            ++listB_iter;
        } else if ((listA_iter+1) != listA.end() && listA_iter->canMerge(*(listA_iter+1))) {
            listA_iter->merge(*(listA_iter+1));
            listA_iter = listA.erase(listA_iter+1);
        } else {
            ++listA_iter;
        }
    }
    while (listA_iter != listA.end()) {
        if ((listA_iter+1) != listA.end() && listA_iter->canMerge(*(listA_iter+1))) {
            listA_iter->merge(*(listA_iter+1));
            listA_iter = listA.erase(listA_iter+1);
        } else {
            ++listA_iter;
        }
    }
    return newEliminations;
}

int main (int argc, char** argv)
{
    std::random_device rd;
    std::mt19937_64 randomGenerator(rd());

    int64_t max_len = std::stoll(argv[1]);
    int64_t num_samples = std::stoll(argv[2]);

    int64_t samplesRemaining = num_samples;
    Tree tree(max_len);
    int64_t poolSize = tree.size() - 1;
    std::deque<Range> eliminationList;
    std::deque<Range> eliminated;
    std::list<std::string> foundList;

    while (samplesRemaining > 0 && poolSize > 0) {
        // find a valid node
        int64_t selectedNode = selectNode(eliminationList, poolSize, randomGenerator);
        NodeInfo info = tree.getNodeInfo(selectedNode);
        foundList.push_back(info.representation);
        samplesRemaining--;

        // determine which nodes this choice eliminates
        eliminated.clear();
        for( auto const& ancestor : info.ancestors) {
            Range r(ancestor, ancestor);
            if(eliminated.empty() || !eliminated.back().canMerge(r)) {
                eliminated.push_back(r);
            } else {
                eliminated.back().merge(r);
            }
        }
        Range r(selectedNode, selectedNode + info.subTreeSize - 1);
        if(eliminated.empty() || !eliminated.back().canMerge(r)) {
            eliminated.push_back(r);
        } else {
            eliminated.back().merge(r);
        }

        // add the eliminated nodes to the existing list
        poolSize -= mergeEliminations(eliminationList, eliminated);
    }

    // Print some stats
    // std::cout << "tree: " << tree.size() << " samplesRemaining: "
    //                       << samplesRemaining << " poolSize: "
    //                       << poolSize << " samples: " << foundList.size()
    //                       << " eliminated: "
    //                       << countEliminated(eliminationList) << std::endl;

    // Print list of binary strings
    // std::cout << "list:";
    // for (auto const& s : foundList) {
    //  std::cout << " " << s;
    // }
    // std::cout << std::endl;
}

其他想法

This algorithm scales extremely well with max_len. Scaling with n is not as good, but based on my own profiling it seems to do much better than the other solutions.

This algorithm could be modified with little effort to allow strings containing more than just '0' and '1'. More possible letters in the words increases the fan-out of the tree, and each selection would eliminate a wider range, with all the nodes in a subtree still being contiguous.