将 numpy.bmat 与 numba 一起使用

Using numpy.bmat with numba

我正在尝试在我的 numba 优化 python 程序中使用 np.bmat。为此,我必须手动定义一个 jitted 函数 bmat,因为不支持来自 numpy 的原生函数:

@njit
def _bmat_2d(matrices):
    arr_rows = []
    for row in matrices:
        arr_rows.append(np.concatenate(row, axis=-1))
    return np.array(np.concatenate(arr_rows, axis=0))

(此代码或多或少是 numpy 代码的简化副本)

但是:

  1. numba 只接受 np.concatenate [1]
  2. 输入中的元组
  3. numba 非常不擅长将任意列表转换为元组 [2]

你有什么想法吗?

参考文献:

以下是否适合您的目的?

import numpy as np
import numba as nb

@nb.njit
def _bmat_2d(m):
    out = np.hstack(m[0])
    for row in m[1:]:
        x = np.hstack(row)
        out = np.vstack((out, x))

    return out

A = np.random.randint(10, size=(3,2))
B = np.random.randint(10, size=(3,1))
C = np.random.randint(10, size=(3,3))
D = np.random.randint(10, size=(4,6))

a = np.bmat(((A, B, C), (D,)))
b = _bmat_2d(((A, B, C), (D,)))

print(np.allclose((a, b))  # True

请注意,您必须传入元组的元组,而不是列表的列表,否则您将收到 "reflected list" 错误,因为当前版本的 Numba 无法处理列表-列表。

由于 np.hstacknumba 不兼容,我不得不编写自己的解决方案。也许你们中的一些人觉得这很有用。它不漂亮,但它完成了工作。

这基本上和 J = np.bmat([[J_1, J_2], [J_3, J_4]]) 做同样的事情。

只需确保更改 J = np.zeros((8, len(J_1[0])*2)) 以适合您想要的输出数组:

import numpy as np
import numba

@numba.njit
def main():
    J_1 = np.array([[-64., 25.6, 25.6, 12.8], [25.6, -25.6, 0., 0.], [25.6, 0., -25.6, 0.], [12.8, 0., 0., -652.8]])
    J_2 = np.array([[-85.33333333, 34.13333333, 34.13333333, 17.06666667], [34.13333333, -34.13333333, 0., 0.], [34.13333333, 0., -34.13333333, 0.], [17.06666667, 0., 0., -870.4]])
    J_3 = np.array([[85.33333333, -34.13333333, -34.13333333, -17.06666667], [-34.13333333, 34.13333333, -0., -0.], [-34.13333333, -0., 34.13333333, -0.], [-17.06666667, -0., -0., 870.4]])
    J_4 = np.array([[-64., 25.6, 25.6, 12.8], [25.6, -25.6, 0., 0.], [25.6, 0., -25.6, 0.], [12.8, 0., 0., -652.8]])

    J = np.zeros((8, len(J_1[0])*2))
    for idx, _ in enumerate(J_1[0]):
        J[0][idx], J[1][idx], J[2][idx], J[3][idx], J[4][idx], J[5][idx], J[6][idx], J[7][idx] = J_1[0][idx], J_1[1][idx], J_1[2][idx], J_1[3][idx], J_3[0][idx], J_3[1][idx], J_3[2][idx], J_3[3][idx]
        J[0][idx+len(J_1[0])], J[1][idx+len(J_1[0])], J[2][idx+len(J_1[0])], J[3][idx+len(J_1[0])], J[4][idx+len(J_1[0])], J[5][idx+len(J_1[0])], J[6][idx+len(J_1[0])], J[7][idx+len(J_1[0])] = J_2[0][idx], J_2[1][idx], J_2[2][idx], J_2[3][idx], J_4[0][idx], J_4[1][idx], J_4[2][idx], J_4[3][idx]

    print(J)

if __name__ == '__main__':
    main()

编辑:

一个人在另一个线程上用这个简单的替换 np.bmat 来帮助我 numba.njit:

J = np.vstack((np.hstack((J_1, J_2)), np.hstack((J_3, J_4))))