在 R 中对子矩阵进行采样

Sampling submatrices in R

给定一个矩阵 mat

> set.seed(1)
> mat <- matrix(rbinom(100,1,0.5),10,10)
> rownames(mat) <- paste0(sample(LETTERS[1:2],10,replace=T),c(1:nrow(mat)))
> colnames(mat) <- paste0(sample(LETTERS[1:2],10,replace=T),c(1:ncol(mat)))
> mat
    A1 A2 A3 B4 B5 B6 B7 A8 B9 B10
B1   0  0  1  0  1  0  1  0  0   0
B2   0  0  0  1  1  1  0  1  1   0
B3   1  1  1  0  1  0  0  0  0   1
A4   1  0  0  0  1  0  0  0  0   1
A5   0  1  0  1  1  0  1  0  1   1
A6   1  0  0  1  1  0  0  1  0   1
A7   1  1  0  1  0  0  0  1  1   0
B8   1  1  0  0  0  1  1  0  0   0
A9   1  0  1  1  1  1  0  1  0   1
A10  0  1  0  0  1  0  1  1  0   1

我想对以下形式的子矩阵进行采样:

   A  B
A  0  1
B  1  0

编辑:具体来说,子矩阵在“A 行/B 列”和“B 行/A 列”这两个单元格中应为 1,在另外两个单元格(A 行/A 列、B 行/B 列)中应为 0。

我可以先随机选择一个 A 行和一个 B 行,再随机选择一个 A 列和一个 B 列,然后检查得到的子矩阵是否符合所需的模式。但是,我想找到一种更高效的方法,使其在大型/稀疏矩阵上也能使用。谢谢!

您可以使用 sample 对维度名称进行采样:

set.seed(1)
# Pick one element of `idx` uniformly at random. Indexing through sample.int()
# avoids the sample() pitfall where a length-one numeric `idx` is interpreted
# as "sample from 1:idx" instead of returning that single index. For vectors
# longer than one element this consumes the RNG exactly as sample(idx, 1) does.
pick_one <- \(idx) idx[sample.int(length(idx), 1L)]
# For the row names and then the column names: the position of one randomly
# chosen name containing "A", followed by one containing "B".
dims <- lapply(dimnames(mat), \(x) c(pick_one(grep("A", x)), pick_one(grep("B", x))))

# 2x2 submatrix with the A row/column first and the B row/column second
mat[dims[[1]], dims[[2]]]

   A3 B4
A4  0  0
B8  0  0

可以枚举所有值为 1 的元素的两两组合,然后剔除共享同一行或同一列的组合,以及对应子矩阵主对角线上的元素不全为 0 的组合。剩余每一对元素的行和列就定义了一个满足所需模式的子矩阵;把所有这样的子矩阵都列举出来之后,从中采样就非常简单了。这种方法适用于值为 1 的元素相对较少的矩阵(例如少于 100K 个,具体取决于可用内存)。

对于含有大量 1 元素的稀疏矩阵,拒绝采样同样是获得高效向量化解法的一种直接途径:为每个候选子矩阵的反对角线随机抽取一对值为 1 的元素,如果对应的主对角线元素不全为 0,则拒绝该样本。下面的解法假设矩阵中 0 元素多于 1 元素。(如果情况相反,则应改为先为主对角线抽取两个 0 元素,若反对角线元素不全为 1 则拒绝。)拒绝率主要取决于矩阵的密度。示例矩阵比较稠密,因此拒绝率较高。

library(data.table)
library(Matrix)

set.seed(1)
m <- matrix(rbinom(100, 1, 0.5), 10, 10)
n <- 20L # number of candidate pairs drawn before rejection
# triplet form: slots @i and @j hold the 0-based coordinates of each 1 element
m <- as(m, "ngTMatrix")
# each row of mIdx selects two 1 elements forming one candidate pair
mIdx <- matrix(sample(length(m@i), 2L*n, TRUE), ncol = 2)
elemA <- mIdx[, 1]
elemB <- mIdx[, 2]
# the sampled 1 elements sit at (row1, col2) and (row2, col1), i.e. on the
# anti-diagonal of the candidate submatrix; + 1L converts to 1-based indices
candidates <- data.table(
  row1 = m@i[elemA],
  col1 = m@j[elemB],
  row2 = m@i[elemB],
  col2 = m@j[elemA]
) + 1L
# keep pairs that span two rows and two columns and whose main-diagonal
# cells (row1, col1) and (row2, col2) are both 0
candidates[
  row1 != row2 & col1 != col2 &
    (m[matrix(c(row1, col1), n)] + m[matrix(c(row2, col2), n)]) == 0
]
#>    row1 col1 row2 col2
#> 1:    1    4    2    7
#> 2:   10    6    2   10
#> 3:    7    7    8    9

下面将其实现为一个函数,它返回指定数量的样本,可选择有放回或无放回采样。

#' Sample 2x2 submatrices with 1s on the anti-diagonal and 0s on the main
#' diagonal, using vectorized rejection sampling
#'
#' Pairs of stored (nonzero) elements are drawn at random to act as the
#' anti-diagonal of a candidate submatrix. A candidate is rejected when the
#' two elements share a row or a column, or when either main-diagonal cell
#' is nonzero. Sampling is repeated in rounds, over-sampling according to the
#' observed acceptance rate, until `n` submatrices are found or `maxIter`
#' rounds have been used.
#'
#' @param m A matrix-like object. Anything that is not already a
#'   triplet-format sparse matrix (`TsparseMatrix`) is converted with
#'   `as(1*m, "TsparseMatrix")`.
#' @param n Number of submatrices to return.
#' @param replace If `TRUE`, the same row may appear more than once in the
#'   result; if `FALSE` (default), the returned rows are unique.
#' @param maxIter Maximum number of sampling rounds (successful or not).
#' @return A data.table with 1-based integer columns `row1`, `col1`, `row2`,
#'   `col2`. For every row, `m[row1, col2]` and `m[row2, col1]` are the
#'   sampled nonzero elements and `m[row1, col1]` and `m[row2, col2]` are
#'   zero. At most `n` rows are returned; fewer (with a message) when
#'   `maxIter` is reached first or when no such submatrix can exist.
#'
#' NOTE: with `replace = FALSE` uniqueness is checked on the four columns as
#' given, so `(row1, col1, row2, col2)` and `(row2, col2, row1, col1)`, which
#' describe the same cells, count as different rows.
sampleSubMat <- function(m, n, replace = FALSE, maxIter = 10L) {
  # convert m to a sparse matrix in triplet format if it's not already;
  # inherits() is used because class() of a base matrix has length 2, which
  # makes a grepl()-based condition an error inside if()
  if (!inherits(m, "TsparseMatrix")) m <- as(1*m, "TsparseMatrix")

  emptyOut <- data.table(
    row1 = integer(0), col1 = integer(0),
    row2 = integer(0), col2 = integer(0)
  )
  if (n <= 0L) return(emptyOut)

  nnz <- length(m@i)
  nCells <- prod(dim(m))
  # the pattern needs two nonzero cells, at least one zero cell and at least
  # two rows and two columns; these cases would also make `mult` infinite
  if (nnz < 2L || nnz >= nCells || any(dim(m) < 2L)) {
    message("No submatrix with the requested pattern can exist")
    return(emptyOut)
  }

  # over-sample based on dimensions and density of the matrix
  # (dim(m) is read directly; `dim(m - 1)` would build a dense copy of m)
  mult <- 1.1/(1 - nnz/nCells)^2/prod(1 - 1/dim(m))
  nLeft <- n
  iter <- 1L

  # draw nCurr candidate pairs and return those matching the pattern
  drawCandidates <- function(nCurr) {
    mIdx <- matrix(sample.int(nnz, 2L*nCurr, TRUE), ncol = 2)
    cand <- data.table(
      row1 = m@i[mIdx[, 1]],
      col1 = m@j[mIdx[, 2]],
      row2 = m@i[mIdx[, 2]],
      col2 = m@j[mIdx[, 1]]
    ) + 1L # @i/@j are 0-based
    cand[
      row1 != row2 & col1 != col2 &
        !(m[matrix(c(row1, col1), ncol = 2L)] +
            m[matrix(c(row2, col2), ncol = 2L)])
    ]
  }

  if (replace) { # sampling with replacement (duplicates allowed)
    # more efficient to store individual data.table objects from each
    # iteration in a list, then rbindlist at the end
    lDT <- vector("list", max(maxIter, 0L))

    while (nLeft > 0L && iter <= maxIter) {
      found <- drawCandidates(ceiling(nLeft*mult))
      lDT[[iter]] <- found
      nFound <- nrow(found)
      if (nFound > 0L) {
        # rescale the multiplier to the acceptance rate just observed
        mult <- 1.1*mult*nLeft/nFound
        nLeft <- nLeft - nFound
      } else {
        # no pattern found, double the samples for the next iteration
        mult <- mult*2
      }
      # every round counts towards maxIter so the loop always terminates
      iter <- iter + 1L
    }
    # rbindlist skips the unused (NULL) slots; emptyOut fixes the columns
    dtOut <- rbindlist(c(list(emptyOut), lDT))
  } else { # sampling without replacement (no duplicates allowed)
    # rbindlist on each iteration to check for duplicates
    dtOut <- emptyOut

    while (nLeft > 0L && iter <= maxIter) {
      nBefore <- nrow(dtOut)
      dtOut <- unique(rbindlist(list(dtOut, drawCandidates(ceiling(nLeft*mult)))))
      nNew <- nrow(dtOut) - nBefore
      if (nNew > 0L) {
        mult <- 1.1*mult*nLeft/nNew
        # dtOut is cumulative, so the remainder is measured against n
        nLeft <- n - nrow(dtOut)
      } else {
        # nothing new (no match or only duplicates): double the samples
        mult <- mult*2
      }
      iter <- iter + 1L
    }
  }

  if (nLeft > 0L) message(sprintf("Max iterations (%i) reached", maxIter))
  # never index past the rows actually found (1:n would pad with NA rows)
  dtOut[seq_len(min(n, nrow(dtOut)))]
}

# 10 unique samples (replace defaults to FALSE) from the current matrix m
(dtSamples1 <- sampleSubMat(m, 10L))
#>     row1 col1 row2 col2
#>  1:    3    6    2   10
#>  2:    3    8    7    5
#>  3:    9    7    5    1
#>  4:    5    8   10    9
#>  5:    7    7    1    8
#>  6:    3    8    9    2
#>  7:    5    1    3    9
#>  8:    5    1    8    5
#>  9:    4    7    1    1
#> 10:   10    6    8    8
# 10 samples with replacement (third argument replace = TRUE)
(dtSamples2 <- sampleSubMat(m, 10L, TRUE))
#>     row1 col1 row2 col2
#>  1:    6    7    5    8
#>  2:    2   10    3    4
#>  3:    7    7   10    1
#>  4:    5    1    4    9
#>  5:    1    8    9    7
#>  6:   10    1    8    8
#>  7:    8    8    7    6
#>  8:    7   10    3    9
#>  9:    2   10    3    9
#> 10:    2    1    3    6
# timing 1k samples from a random 10k-square matrix with 1M elements
# idx: 1e6 distinct linear cell positions drawn from the 1e8 cells
idx <- sample(1e8, 1e6)
# decode each linear position into 0-based row (%%) and column (%/%) indices
# for the triplet slots i and j
m <- new("ngTMatrix", i = as.integer(((idx - 1) %% 1e4)), j = as.integer(((idx - 1) %/% 1e4)), Dim = c(1e4L, 1e4L))
system.time(dtSamples3 <- sampleSubMat(m, 1e3L)) # without replacement
#>    user  system elapsed 
#>    1.08    0.31    1.40
system.time(dtSamples4 <- sampleSubMat(m, 1e3L, TRUE)) # with replacement
#>    user  system elapsed 
#>    0.89    0.32    1.21
Created on 2022-04-29 by the reprex package (v2.0.1)