快速生成大规模随机ndarray的方法

Fast way to generate large-scale random ndarray

我想生成一个形状为 (1e7, 800) 的随机矩阵。但我发现 numpy.random.rand() 在这种规模下变得非常慢。有没有更快的方法?

一个简单的方法是使用 Numba 编写多线程实现:

import numba as nb
import random

@nb.njit('float64[:,:](int_, int_)', parallel=True)
def genRandom(n, m):
    res = np.empty((n, m))

    # Parallel loop
    for i in nb.prange(n):
        for j in range(m):
            res[i, j] = np.random.rand()

    return res

这比我的 6 核机器上的 np.random.rand()6.4 倍

请注意,使用 32 位浮点数 可能有助于加快计算速度,但精度会降低。

Numba 是一个不错的选择,另一个可能效果很好的选项是 dask.array,它将创建 numpy 数组的惰性块并在块上执行并行计算。在我的机器上,我的速度提高了 2 倍(对于 1e6 x 1e3 矩阵,因为我的机器上没有足够的内存)。

rows = 10**6
cols = 10**3
import dask.array as da
x = da.random.random(size=(rows, cols)).compute() # takes about 5 seconds

# import numpy as np
# x = np.random.rand(rows, cols) # takes about 10 seconds

请注意,最后的 .compute 只是将计算的数组带入内存,但通常您可以继续利用 dask 的并行计算来获得更快的计算(也可以扩展到超出单机),参见 docs.

尝试从目前给出的答案中寻找答案:

我刚刚编写了一个脚本,该脚本是根据已经给出的(SultanOrazbayev and Jérôme Richard)答案编译的,每个 numbadasknumpy 方法包含 3 个函数,并且测量 n 个不同大小的数组所花费的时间。

密码

import dask.array as da
import matplotlib.pyplot as plt
import numba as nb
import timeit
import numpy as np


@nb.njit('float64[:,:](int_, int_)', parallel=True)
def nmb(n, m):
    res = np.empty((n, m))

    # Parallel loop
    for i in nb.prange(n):
        for j in range(m):
            res[i, j] = np.random.rand()

    return res


def nmp(n, m):
    return np.random.random((n, m))


def dask(n, m):
    return da.random.random(size=(n, m)).compute()


if __name__ == '__main__':
    data = []
    for i in range(1, 16):
        dmm = 2 ** i
        s_nmb = timeit.default_timer()
        nmb(dmm, dmm)
        e_nmb = timeit.default_timer()

        s_nmp = timeit.default_timer()
        nmp(dmm, dmm)
        e_nmp = timeit.default_timer()

        s_dask = timeit.default_timer()
        dask(dmm, dmm)
        e_dask = timeit.default_timer()

        data.append([
            dmm,
            e_nmb - s_nmb,
            e_nmp - s_nmp,
            e_dask - s_dask
        ])

    data = np.array(data)

    plt.plot(data[:, 0], data[:, 1], "-r", label="Numba")
    plt.plot(data[:, 0], data[:, 2], "-g", label="Numpy")
    plt.plot(data[:, 0], data[:, 3], "-b", label="Dask")

    plt.xlabel("Number of Element on each axes")
    plt.ylabel("Time spent (s)")
    plt.legend()
    plt.show()

结果