random.multivariate_normal 在 dask 阵列上?

random.multivariate_normal on a dask array?

我一直在努力寻找一种方法来获得适用于简单工作流程的计算器。

我有使用 np.random.mulivariate_normal 函数的代码,虽然我们可以在 dask 数组上使用其中的许多类型,但似乎这个不行。 Sooo....我试图根据 dask documentation.

中提供的示例创建自己的示例

这是我的尝试,它给出了我难以理解的错误。我还提供了随机输入变量以使其易于复制:

import numpy as np
from dask.distributed import Client
import dask.array as da

def mvn(mu, sigma, n, blocksize):
    chunks = ((blocksize,) * (n // blocksize),
              (blocksize,) * (n // blocksize))

    name = 'mvn'   # unique identifier

    dsk = {(name, i, j): (np.random.multivariate_normal(mu,sigma, blocksize))
                         if i == j else
                         (np.zeros, (blocksize, blocksize))
             for i in range(n // blocksize)
             for j in range(n // blocksize)}

    dtype = np.random.multivariate_normal(0).dtype  # take dtype default from numpy

    return da.Array(dsk, name, chunks, dtype)

n = 10000
A = da.random.normal(0, 1, size=(n,n), chunks=(1000, 1000))
sigma = da.dot(A,A.transpose())
mu = 4.0*da.ones(n, chunks = 1000)
R =  da.numpy.random.mvn(mu, sigma, n, chunks=(100))

有什么建议吗,或者我是否离题太远以至于我应该放弃所有希望?谢谢!

这个答案可以充实,但我想你会更容易使用 dask 的 delayedda.from_delayedda.*stack

我看到你所拥有的一个直接问题:np.random.multivariate_normal(mu,sigma, blocksize) 你直接调用函数,而不是制定规范。您可能想要 (np.random.multivariate_normal, mu,sigma, blocksize)。这表明使用原始 dask 字典可能很棘手!

如果你有一个 运行 这个集群,你可以使用我在 的回答,复制到这里以供参考:

目前的一个解决方法是使用 cholesky 分解。请注意,任何协方差矩阵 C 都可以表示为 C=G*G'。然后,如果 y 是标准正态分布,则 x = G'*y 与 C 中指定的相关(参见 StackExchange Mathematic 上的 excellent post)。在代码中:

Numpy

n_dim =4
size = 100000
A = np.random.randn(n_dim, n_dim)
covm = A.dot(A.T)

x=  np.random.multivariate_normal(size=size, mean=np.zeros(len(covm)),cov=covm)
## verify numpys covariance is correct
np.cov(x, rowvar=False)
covm

达斯克

## create covariance matrix
A = da.random.standard_normal(size=(n_dim, n_dim),chunks=(2,2))
covm = A.dot(A.T)

## get cholesky decomp
L = da.linalg.cholesky(covm, lower=True)

## drawn standard normal 
sn= da.random.standard_normal(size=(size, n_dim),chunks=(100,100))

## correct for correlation
x =L.dot(sn.T)
x.shape

## verify
covm.compute()
da.cov(x, rowvar=True).compute()