使用 cblas_sgemm 执行复杂的矩阵操作运算以执行乘法
Performing complicated matrix manipulation operations with cblas_sgemm in order to carry out multiplication
我有 100 个 3x3x3
矩阵,我想将它们与另一个大小为 3x5x5
的大矩阵相乘(类似于用多个过滤器对一个图像进行卷积,但不完全是)。
为了便于解释,这是我的大矩阵的样子:
>>> x = np.arange(75).reshape(3, 5, 5)
>>> x
array([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]],
[[25, 26, 27, 28, 29],
[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39],
[40, 41, 42, 43, 44],
[45, 46, 47, 48, 49]],
[[50, 51, 52, 53, 54],
[55, 56, 57, 58, 59],
[60, 61, 62, 63, 64],
[65, 66, 67, 68, 69],
[70, 71, 72, 73, 74]]])
在内存中,我假设大矩阵中的所有子矩阵都存储在连续的位置(如果我错了请纠正我)。我想做的是,从这个 3x5x5
矩阵中,我想从大矩阵的每个子矩阵中提取 3 5x3
列,然后将它们水平连接以获得 5x9
矩阵(如果这部分不清楚,我深表歉意,如果需要我可以更详细地解释)。如果我使用 numpy,我会这样做:
>>> k = np.hstack(np.vstack(x)[:, 0:3].reshape(3, 5, 3))
>>> k
array([[ 0, 1, 2, 25, 26, 27, 50, 51, 52],
[ 5, 6, 7, 30, 31, 32, 55, 56, 57],
[10, 11, 12, 35, 36, 37, 60, 61, 62],
[15, 16, 17, 40, 41, 42, 65, 66, 67],
[20, 21, 22, 45, 46, 47, 70, 71, 72]])
但是,我没有使用 python,所以我无法访问我需要的 numpy 函数,以便将数据块重塑为我想要执行乘法的形式...我只能在C中直接调用cblas_sgemm
函数(来自BLAS库),其中k
对应输入B。
这是我给 cblas_sgemm
的电话:
cblas_sgemm( CblasRowMajor, CblasNoTrans, CblasTrans,
100, 5, 9,
1.0,
A, 9,
B, 9, // this is actually wrong, since I don't know how to specify the right parameter
0.0,
result, 5);
基本上,ldb
属性是这里的罪魁祸首,因为我的数据没有按照我需要的方式被阻止。我尝试了不同的方法,但我无法 cblas_sgemm
理解我希望它如何读取和理解我的数据。
总之不知道怎么说cblas_sgemm
去读x
like k
。有没有一种方法可以在 python 中巧妙地重塑我的数据,然后再将其发送到 C,以便 cblas_sgemm
可以按照我想要的方式工作吗?
我会通过设置CblasTrans转置k
,所以在乘法时,B是9x5
。我的矩阵 A 的形状是 100x9
。希望对您有所帮助。
如有任何帮助,我们将不胜感激。谢谢!
In short, I don't know how to tell cblas_sgemm to read x like k.
你不能。你必须复印一份。
考虑 k
:
In [20]: k
Out[20]:
array([[ 0, 1, 2, 25, 26, 27, 50, 51, 52],
[ 5, 6, 7, 30, 31, 32, 55, 56, 57],
[10, 11, 12, 35, 36, 37, 60, 61, 62],
[15, 16, 17, 40, 41, 42, 65, 66, 67],
[20, 21, 22, 45, 46, 47, 70, 71, 72]])
在二维数组中,内存中元素的间距在每个轴上必须相同。您从 x
的创建方式知道内存中的连续元素是 0, 1, 2, 3, 4, ...
,但是 k
的第一行包含 0, 1, 2, 25, 26, ....
。 1
和 2
之间没有间距(即内存地址增加数组的一个元素的大小),但是 2
和 [= 之间内存有很大的跳跃22=]。所以你必须制作一个副本才能创建 k
.
话虽如此,还有一种替代方法可以使用一些重塑(无需复制)和 numpy 的 einsum
函数来有效地实现您想要的最终结果。
这是一个例子。首先定义x
和A
:
In [52]: x = np.arange(75).reshape(3, 5, 5)
In [53]: A = np.arange(90).reshape(10, 9)
这是我对你想要达到的目标的理解; A.dot(k.T)
是想要的结果:
In [54]: k = np.hstack(np.vstack(x)[:, 0:3].reshape(3, 5, 3))
In [55]: A.dot(k.T)
Out[55]:
array([[ 1392, 1572, 1752, 1932, 2112],
[ 3498, 4083, 4668, 5253, 5838],
[ 5604, 6594, 7584, 8574, 9564],
[ 7710, 9105, 10500, 11895, 13290],
[ 9816, 11616, 13416, 15216, 17016],
[11922, 14127, 16332, 18537, 20742],
[14028, 16638, 19248, 21858, 24468],
[16134, 19149, 22164, 25179, 28194],
[18240, 21660, 25080, 28500, 31920],
[20346, 24171, 27996, 31821, 35646]])
以下是通过切片 x
和重塑 A
获得相同结果的方法:
In [56]: x2 = x[:,:,:3]
In [57]: A2 = A.reshape(-1, 3, 3)
In [58]: einsum('ijk,jlk', A2, x2)
Out[58]:
array([[ 1392, 1572, 1752, 1932, 2112],
[ 3498, 4083, 4668, 5253, 5838],
[ 5604, 6594, 7584, 8574, 9564],
[ 7710, 9105, 10500, 11895, 13290],
[ 9816, 11616, 13416, 15216, 17016],
[11922, 14127, 16332, 18537, 20742],
[14028, 16638, 19248, 21858, 24468],
[16134, 19149, 22164, 25179, 28194],
[18240, 21660, 25080, 28500, 31920],
[20346, 24171, 27996, 31821, 35646]])
我有 100 个 3x3x3
矩阵,我想将它们与另一个大小为 3x5x5
的大矩阵相乘(类似于用多个过滤器对一个图像进行卷积,但不完全是)。
为了便于解释,这是我的大矩阵的样子:
>>> x = np.arange(75).reshape(3, 5, 5)
>>> x
array([[[ 0, 1, 2, 3, 4],
[ 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]],
[[25, 26, 27, 28, 29],
[30, 31, 32, 33, 34],
[35, 36, 37, 38, 39],
[40, 41, 42, 43, 44],
[45, 46, 47, 48, 49]],
[[50, 51, 52, 53, 54],
[55, 56, 57, 58, 59],
[60, 61, 62, 63, 64],
[65, 66, 67, 68, 69],
[70, 71, 72, 73, 74]]])
在内存中,我假设大矩阵中的所有子矩阵都存储在连续的位置(如果我错了请纠正我)。我想做的是,从这个 3x5x5
矩阵中,我想从大矩阵的每个子矩阵中提取 3 5x3
列,然后将它们水平连接以获得 5x9
矩阵(如果这部分不清楚,我深表歉意,如果需要我可以更详细地解释)。如果我使用 numpy,我会这样做:
>>> k = np.hstack(np.vstack(x)[:, 0:3].reshape(3, 5, 3))
>>> k
array([[ 0, 1, 2, 25, 26, 27, 50, 51, 52],
[ 5, 6, 7, 30, 31, 32, 55, 56, 57],
[10, 11, 12, 35, 36, 37, 60, 61, 62],
[15, 16, 17, 40, 41, 42, 65, 66, 67],
[20, 21, 22, 45, 46, 47, 70, 71, 72]])
但是,我没有使用 python,所以我无法访问我需要的 numpy 函数,以便将数据块重塑为我想要执行乘法的形式...我只能在C中直接调用cblas_sgemm
函数(来自BLAS库),其中k
对应输入B。
这是我给 cblas_sgemm
的电话:
cblas_sgemm( CblasRowMajor, CblasNoTrans, CblasTrans,
100, 5, 9,
1.0,
A, 9,
B, 9, // this is actually wrong, since I don't know how to specify the right parameter
0.0,
result, 5);
基本上,ldb
属性是这里的罪魁祸首,因为我的数据没有按照我需要的方式被阻止。我尝试了不同的方法,但我无法 cblas_sgemm
理解我希望它如何读取和理解我的数据。
总之不知道怎么说cblas_sgemm
去读x
like k
。有没有一种方法可以在 python 中巧妙地重塑我的数据,然后再将其发送到 C,以便 cblas_sgemm
可以按照我想要的方式工作吗?
我会通过设置CblasTrans转置k
,所以在乘法时,B是9x5
。我的矩阵 A 的形状是 100x9
。希望对您有所帮助。
如有任何帮助,我们将不胜感激。谢谢!
In short, I don't know how to tell cblas_sgemm to read x like k.
你不能。你必须复印一份。
考虑 k
:
In [20]: k
Out[20]:
array([[ 0, 1, 2, 25, 26, 27, 50, 51, 52],
[ 5, 6, 7, 30, 31, 32, 55, 56, 57],
[10, 11, 12, 35, 36, 37, 60, 61, 62],
[15, 16, 17, 40, 41, 42, 65, 66, 67],
[20, 21, 22, 45, 46, 47, 70, 71, 72]])
在二维数组中,内存中元素的间距在每个轴上必须相同。您从 x
的创建方式知道内存中的连续元素是 0, 1, 2, 3, 4, ...
,但是 k
的第一行包含 0, 1, 2, 25, 26, ....
。 1
和 2
之间没有间距(即内存地址增加数组的一个元素的大小),但是 2
和 [= 之间内存有很大的跳跃22=]。所以你必须制作一个副本才能创建 k
.
话虽如此,还有一种替代方法可以使用一些重塑(无需复制)和 numpy 的 einsum
函数来有效地实现您想要的最终结果。
这是一个例子。首先定义x
和A
:
In [52]: x = np.arange(75).reshape(3, 5, 5)
In [53]: A = np.arange(90).reshape(10, 9)
这是我对你想要达到的目标的理解; A.dot(k.T)
是想要的结果:
In [54]: k = np.hstack(np.vstack(x)[:, 0:3].reshape(3, 5, 3))
In [55]: A.dot(k.T)
Out[55]:
array([[ 1392, 1572, 1752, 1932, 2112],
[ 3498, 4083, 4668, 5253, 5838],
[ 5604, 6594, 7584, 8574, 9564],
[ 7710, 9105, 10500, 11895, 13290],
[ 9816, 11616, 13416, 15216, 17016],
[11922, 14127, 16332, 18537, 20742],
[14028, 16638, 19248, 21858, 24468],
[16134, 19149, 22164, 25179, 28194],
[18240, 21660, 25080, 28500, 31920],
[20346, 24171, 27996, 31821, 35646]])
以下是通过切片 x
和重塑 A
获得相同结果的方法:
In [56]: x2 = x[:,:,:3]
In [57]: A2 = A.reshape(-1, 3, 3)
In [58]: einsum('ijk,jlk', A2, x2)
Out[58]:
array([[ 1392, 1572, 1752, 1932, 2112],
[ 3498, 4083, 4668, 5253, 5838],
[ 5604, 6594, 7584, 8574, 9564],
[ 7710, 9105, 10500, 11895, 13290],
[ 9816, 11616, 13416, 15216, 17016],
[11922, 14127, 16332, 18537, 20742],
[14028, 16638, 19248, 21858, 24468],
[16134, 19149, 22164, 25179, 28194],
[18240, 21660, 25080, 28500, 31920],
[20346, 24171, 27996, 31821, 35646]])